YARN-9015. [DevicePlugin] Add an interface for device plugin to provide customized scheduler. (Zhankun Tang via wangda)

Change-Id: Ib2e4ae47a6f29bb3082c1f8520cf5a52ca720979
This commit is contained in:
Wangda Tan 2018-12-12 11:44:22 -08:00
parent c771fe6e10
commit 61bdcb7b2b
8 changed files with 300 additions and 14 deletions

View File

@ -0,0 +1,38 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin;
import java.util.Set;
/**
* An optional interface a device plugin may implement when it needs custom
* device scheduling. If a plugin does not implement it, the device framework
* falls back to its own default scheduling.
*/
public interface DevicePluginScheduler {
/**
* Called when allocating devices. The framework does all device
* bookkeeping and failure recovery, so this hook can be stateless and
* only needs to choose from the available devices passed in. It may be
* invoked multiple times by the framework.
*
* <p>The framework caller expects the returned set to contain exactly
* {@code count} devices and fails the allocation otherwise. The
* {@code availableDevices} set is handed over as a read-only view, so
* implementations should not try to modify it.
*
* @param availableDevices Devices allowed to be chosen from.
* @param count Number of devices to be allocated.
* @return The set of {@link Device}s chosen for this allocation.
*/
Set<Device> allocateDevices(Set<Device> availableDevices, int count);
}

View File

@ -29,6 +29,7 @@ import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.exceptions.YarnRuntimeException; import org.apache.hadoop.yarn.exceptions.YarnRuntimeException;
import org.apache.hadoop.yarn.server.nodemanager.Context; import org.apache.hadoop.yarn.server.nodemanager.Context;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework.DeviceMappingManager; import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework.DeviceMappingManager;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework.DevicePluginAdapter; import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework.DevicePluginAdapter;
@ -129,9 +130,12 @@ public class ResourcePluginManager {
Configuration configuration, Configuration configuration,
Map<String, ResourcePlugin> pluginMap) Map<String, ResourcePlugin> pluginMap)
throws YarnRuntimeException, ClassNotFoundException { throws YarnRuntimeException, ClassNotFoundException {
LOG.info("The pluggable device framework enabled," + LOG.info("The pluggable device framework enabled,"
"trying to load the vendor plugins"); + "trying to load the vendor plugins");
deviceMappingManager = new DeviceMappingManager(context); if (null == deviceMappingManager) {
LOG.debug("DeviceMappingManager initialized.");
deviceMappingManager = new DeviceMappingManager(context);
}
String[] pluginClassNames = configuration.getStrings( String[] pluginClassNames = configuration.getStrings(
YarnConfiguration.NM_PLUGGABLE_DEVICE_FRAMEWORK_DEVICE_CLASSES); YarnConfiguration.NM_PLUGGABLE_DEVICE_FRAMEWORK_DEVICE_CLASSES);
if (null == pluginClassNames) { if (null == pluginClassNames) {
@ -193,6 +197,19 @@ public class ResourcePluginManager {
LOG.info("Adapter of {} init success!", pluginClassName); LOG.info("Adapter of {} init success!", pluginClassName);
// Store plugin as adapter instance // Store plugin as adapter instance
pluginMap.put(request.getResourceName(), pluginAdapter); pluginMap.put(request.getResourceName(), pluginAdapter);
// If the device plugin implements DevicePluginScheduler interface
if (dpInstance instanceof DevicePluginScheduler) {
// check DevicePluginScheduler interface compatibility
checkInterfaceCompatibility(DevicePluginScheduler.class, pluginClazz);
LOG.info(
"{} can schedule {} devices."
+ "Added as preferred device plugin scheduler",
pluginClassName,
resourceName);
deviceMappingManager.addDevicePluginScheduler(
resourceName,
(DevicePluginScheduler) dpInstance);
}
} // end for } // end for
} }
@ -243,6 +260,12 @@ public class ResourcePluginManager {
return true; return true;
} }
/**
* Replaces the {@link DeviceMappingManager} used by this plugin manager.
* Plugin initialization only creates a manager when this field is still
* null, so a manager injected here beforehand (e.g. a spy) is kept.
*
* @param deviceMappingManager the manager instance to use
*/
@VisibleForTesting
public void setDeviceMappingManager(
DeviceMappingManager deviceMappingManager) {
this.deviceMappingManager = deviceMappingManager;
}
public DeviceMappingManager getDeviceMappingManager() { public DeviceMappingManager getDeviceMappingManager() {
return deviceMappingManager; return deviceMappingManager;
} }

View File

@ -29,18 +29,20 @@ import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.exceptions.ResourceNotFoundException; import org.apache.hadoop.yarn.exceptions.ResourceNotFoundException;
import org.apache.hadoop.yarn.server.nodemanager.Context; import org.apache.hadoop.yarn.server.nodemanager.Context;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container; import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException; import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
import java.io.IOException; import java.io.IOException;
import java.io.Serializable; import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.TreeSet; import java.util.TreeSet;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Collections;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
/** /**
@ -54,6 +56,10 @@ public class DeviceMappingManager {
private Context nmContext; private Context nmContext;
private static final int WAIT_MS_PER_LOOP = 1000; private static final int WAIT_MS_PER_LOOP = 1000;
// Holds vendor implemented scheduler
private Map<String, DevicePluginScheduler> devicePluginSchedulers =
new ConcurrentHashMap<>();
/** /**
* Hold all type of devices. * Hold all type of devices.
* key is the device resource name * key is the device resource name
@ -84,6 +90,11 @@ public class DeviceMappingManager {
return allUsedDevices; return allUsedDevices;
} }
/**
* Returns the vendor-implemented schedulers keyed by device resource name.
* This is the live internal map, not a copy: entries put into the returned
* map are seen by the allocation path.
*
* @return the internal resource-name to scheduler map
*/
@VisibleForTesting
public Map<String, DevicePluginScheduler> getDevicePluginSchedulers() {
return devicePluginSchedulers;
}
public synchronized void addDeviceSet(String resourceName, public synchronized void addDeviceSet(String resourceName,
Set<Device> deviceSet) { Set<Device> deviceSet) {
LOG.info("Adding new resource: " + "type:" LOG.info("Adding new resource: " + "type:"
@ -162,9 +173,10 @@ public class DeviceMappingManager {
Set<Device> assignedDevices = new TreeSet<>(); Set<Device> assignedDevices = new TreeSet<>();
Map<Device, ContainerId> usedDevices = allUsedDevices.get(resourceName); Map<Device, ContainerId> usedDevices = allUsedDevices.get(resourceName);
Set<Device> allowedDevices = allAllowedDevices.get(resourceName); Set<Device> allowedDevices = allAllowedDevices.get(resourceName);
DevicePluginScheduler dps = devicePluginSchedulers.get(resourceName);
defaultScheduleAction(allowedDevices, usedDevices, // Prefer DevicePluginScheduler logic
assignedDevices, containerId, requestedDeviceCount); pickAndDoSchedule(allowedDevices, usedDevices, assignedDevices,
containerId, requestedDeviceCount, resourceName, dps);
// Record in state store if we allocated anything // Record in state store if we allocated anything
if (!assignedDevices.isEmpty()) { if (!assignedDevices.isEmpty()) {
@ -273,7 +285,46 @@ public class DeviceMappingManager {
return releasingDevices; return releasingDevices;
} }
// default scheduling logic /**
* If device plugin has own scheduler, then use it.
* Otherwise, pick our default scheduler to do scheduling.
* */
private void pickAndDoSchedule(Set<Device> allowed,
    Map<Device, ContainerId> used, Set<Device> assigned,
    ContainerId containerId, int count, String resourceName,
    DevicePluginScheduler dps) throws ResourceHandlerException {
  // Contract:
  //  - allowed:  all devices of this resource that may be handed out
  //  - used:     device -> owning container; updated with new assignments
  //  - assigned: output set, receives the devices chosen for containerId
  //  - dps:      vendor scheduler for resourceName, or null if none
  // Throws ResourceHandlerException when the vendor scheduler does not
  // return exactly `count` devices (including when it returns null).
  if (null == dps) {
    // No vendor scheduler registered for this resource: use the default
    LOG.debug("Customized device plugin scheduler is preferred "
        + "but not implemented, use default logic");
    defaultScheduleAction(allowed, used,
        assigned, containerId, count);
    return;
  }

  LOG.debug("Customized device plugin implemented, "
      + "use customized logic");
  LOG.debug("Try to schedule " + count
      + "(" + resourceName + ") using " + dps.getClass());

  // Hand the plugin an unmodifiable view of the devices not yet in use
  Set<Device> dpsAllocated = dps.allocateDevices(
      Sets.difference(allowed, used.keySet()),
      count);

  // Validate the plugin's answer before touching any bookkeeping. Report
  // what the plugin actually returned; `assigned` is still untouched here
  // and would always read as empty.
  if (null == dpsAllocated || dpsAllocated.size() != count) {
    String actual = (null == dpsAllocated)
        ? "null" : String.valueOf(dpsAllocated.size());
    throw new ResourceHandlerException(dps.getClass()
        + " should allocate " + count
        + " of " + resourceName + ", but actual: "
        + actual);
  }

  // Copy the plugin's choice into the caller's result set
  assigned.addAll(dpsAllocated);
  // Record every assigned device as used by this container
  for (Device device : assigned) {
    used.put(device, containerId);
  }
}
// Default scheduling logic
private void defaultScheduleAction(Set<Device> allowed, private void defaultScheduleAction(Set<Device> allowed,
Map<Device, ContainerId> used, Set<Device> assigned, Map<Device, ContainerId> used, Set<Device> assigned,
ContainerId containerId, int count) { ContainerId containerId, int count) {
@ -307,7 +358,6 @@ public class DeviceMappingManager {
} }
} }
public Set<Device> getAllowed() { public Set<Device> getAllowed() {
return allowed; return allowed;
} }
@ -321,4 +371,11 @@ public class DeviceMappingManager {
} }
/**
 * Registers the vendor-implemented scheduler to be preferred over the
 * default scheduling logic for the given device resource. A scheduler
 * already registered under the same resource name is replaced.
 *
 * <p>This is invoked while loading plugins as well as from tests, so it is
 * not marked as test-only.
 *
 * @param resourceName device resource name the scheduler is responsible for
 * @param s the plugin's scheduler implementation
 * @throws NullPointerException if {@code resourceName} or {@code s} is null
 */
public synchronized void addDevicePluginScheduler(String resourceName,
    DevicePluginScheduler s) {
  // Validate both arguments up front so a failure names the offending
  // argument instead of surfacing from inside the concurrent map.
  Objects.requireNonNull(resourceName, "resourceName must not be null");
  Objects.requireNonNull(s, "DevicePluginScheduler must not be null");
  this.devicePluginSchedulers.put(resourceName, s);
}
} }

View File

@ -44,8 +44,8 @@ public class DevicePluginAdapter implements ResourcePlugin {
private final String resourceName; private final String resourceName;
private final DevicePlugin devicePlugin; private final DevicePlugin devicePlugin;
private DeviceMappingManager deviceMappingManager; private DeviceMappingManager deviceMappingManager;
private DeviceResourceUpdaterImpl deviceResourceUpdater;
private DeviceResourceHandlerImpl deviceResourceHandler; private DeviceResourceHandlerImpl deviceResourceHandler;
private DeviceResourceUpdaterImpl deviceResourceUpdater;
public DevicePluginAdapter(String name, DevicePlugin dp, public DevicePluginAdapter(String name, DevicePlugin dp,
DeviceMappingManager dmm) { DeviceMappingManager dmm) {

View File

@ -35,6 +35,7 @@ import org.apache.hadoop.yarn.server.nodemanager.NodeManager;
import org.apache.hadoop.yarn.server.nodemanager.NodeManagerTestBase; import org.apache.hadoop.yarn.server.nodemanager.NodeManagerTestBase;
import org.apache.hadoop.yarn.server.nodemanager.NodeStatusUpdater; import org.apache.hadoop.yarn.server.nodemanager.NodeStatusUpdater;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.ContainerManagerImpl; import org.apache.hadoop.yarn.server.nodemanager.containermanager.ContainerManagerImpl;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container; import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation; import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
@ -494,6 +495,33 @@ public class TestResourcePluginManager extends NodeManagerTestBase {
Assert.assertEquals(expectedMessage, actualMessage); Assert.assertEquals(expectedMessage, actualMessage);
} }
/**
* Loads two vendor plugins, of which only FakeTestDevicePlugin1 implements
* DevicePluginScheduler, and checks that exactly one scheduler ends up
* registered in the DeviceMappingManager.
*/
@Test
public void testLoadPluginWithCustomizedScheduler() {
ResourcePluginManager rpm = new ResourcePluginManager();
DeviceMappingManager dmm = new DeviceMappingManager(mock(Context.class));
DeviceMappingManager dmmSpy = spy(dmm);
ResourcePluginManager rpmSpy = spy(rpm);
// Inject the spy up front; plugin loading keeps an already-set manager
rpmSpy.setDeviceMappingManager(dmmSpy);
nm = new MyMockNM(rpmSpy);
conf.setBoolean(YarnConfiguration.NM_PLUGGABLE_DEVICE_FRAMEWORK_ENABLED,
true);
// Plugin1 has a customized scheduler, Plugin5 does not
conf.setStrings(
YarnConfiguration.NM_PLUGGABLE_DEVICE_FRAMEWORK_DEVICE_CLASSES,
FakeTestDevicePlugin1.class.getCanonicalName()
+ "," + FakeTestDevicePlugin5.class.getCanonicalName());
nm.init(conf);
nm.start();
// only 1 plugin has the customized scheduler
// NOTE(review): this verifies the DevicePlugin compatibility check, not
// the DevicePluginScheduler one -- confirm whether the latter was intended.
verify(rpmSpy, times(1)).checkInterfaceCompatibility(
DevicePlugin.class, FakeTestDevicePlugin1.class);
verify(dmmSpy, times(1)).addDevicePluginScheduler(
any(String.class), any(DevicePluginScheduler.class));
// NOTE(review): asserts on dmm rather than dmmSpy; presumably the spy
// shares the same scheduler map instance -- confirm.
Assert.assertEquals(1, dmm.getDevicePluginSchedulers().size());
}
@Test(timeout = 30000) @Test(timeout = 30000)
public void testRequestedResourceNameIsConfigured() public void testRequestedResourceNameIsConfigured()
throws Exception{ throws Exception{

View File

@ -27,7 +27,8 @@ import java.util.TreeSet;
* Used only for testing. * Used only for testing.
* A fake normal vendor plugin * A fake normal vendor plugin
* */ * */
public class FakeTestDevicePlugin1 implements DevicePlugin { public class FakeTestDevicePlugin1
implements DevicePlugin, DevicePluginScheduler{
@Override @Override
public DeviceRegisterRequest getRegisterRequestInfo() { public DeviceRegisterRequest getRegisterRequestInfo() {
return DeviceRegisterRequest.Builder.newInstance() return DeviceRegisterRequest.Builder.newInstance()
@ -58,4 +59,19 @@ public class FakeTestDevicePlugin1 implements DevicePlugin {
public void onDevicesReleased(Set<Device> allocatedDevices) { public void onDevicesReleased(Set<Device> allocatedDevices) {
} }
/**
 * Naive test scheduler: takes devices in iteration order and stops as soon
 * as {@code count} of them have been taken.
 */
@Override
public Set<Device> allocateDevices(Set<Device> availableDevices,
    int count) {
  Set<Device> picked = new TreeSet<Device>();
  // Count down from the requested amount; stop once it reaches zero
  int remaining = count;
  for (Device candidate : availableDevices) {
    picked.add(candidate);
    remaining--;
    if (remaining == 0) {
      return picked;
    }
  }
  return picked;
}
} }

View File

@ -0,0 +1,56 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import java.util.Set;
/**
* Used only for testing.
* A normal plugin without customized scheduler: it implements
* {@link DevicePlugin} only, not DevicePluginScheduler.
*/
public class FakeTestDevicePlugin5 implements DevicePlugin {
/** Registers under the resource name "cmp.com/cmp". */
@Override
public DeviceRegisterRequest getRegisterRequestInfo() throws Exception {
return DeviceRegisterRequest.Builder.newInstance()
.setResourceName("cmp.com/cmp").build();
}
/** Stub: reports no devices (returns null). */
@Override
public Set<Device> getDevices() throws Exception {
return null;
}
/** Stub: provides no runtime spec (returns null). */
@Override
public DeviceRuntimeSpec onDevicesAllocated(Set<Device> allocatedDevices,
YarnRuntimeType yarnRuntime) throws Exception {
return null;
}
/** Stub: nothing to clean up on release. */
@Override
public void onDevicesReleased(Set<Device> releasedDevices) throws Exception {
}
}

View File

@ -31,6 +31,7 @@ import org.apache.hadoop.yarn.server.nodemanager.Context;
import org.apache.hadoop.yarn.server.nodemanager.NodeManager; import org.apache.hadoop.yarn.server.nodemanager.NodeManager;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType; import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
@ -63,6 +64,8 @@ import java.util.Set;
import java.util.TreeSet; import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.isA; import static org.mockito.Matchers.isA;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
@ -439,6 +442,56 @@ public class TestDevicePluginAdapter {
} }
/**
* Registers a plugin that also implements DevicePluginScheduler and checks
* that a container's device request is served through the plugin's
* allocateDevices hook, with the framework bookkeeping updated accordingly.
*/
@Test
public void testPreferPluginScheduler() throws IOException, YarnException {
NodeManager.NMContext context = mock(NodeManager.NMContext.class);
NMStateStoreService storeService = mock(NMStateStoreService.class);
when(context.getNMStateStore()).thenReturn(storeService);
// Storing assigned resources in the state store is a no-op in this test
doNothing().when(storeService).storeAssignedResources(isA(Container.class),
isA(String.class),
isA(ArrayList.class));
// Init scheduler manager
DeviceMappingManager dmm = new DeviceMappingManager(context);
ResourcePluginManager rpm = mock(ResourcePluginManager.class);
when(rpm.getDeviceMappingManager()).thenReturn(dmm);
// Init a plugin; spy on it so calls to its scheduler can be verified
MyPlugin plugin = new MyPlugin();
MyPlugin spyPlugin = spy(plugin);
String resourceName = MyPlugin.RESOURCE_NAME;
// Register the plugin as the scheduler for its resource in
// DeviceMappingManager
dmm.getDevicePluginSchedulers().put(MyPlugin.RESOURCE_NAME, spyPlugin);
// Init an adapter for the plugin
DevicePluginAdapter adapter = new DevicePluginAdapter(
resourceName,
spyPlugin, dmm);
// Bootstrap, adding device
adapter.initialize(context);
adapter.createResourceHandler(context,
mockCGroupsHandler, mockPrivilegedExecutor);
adapter.getDeviceResourceHandler().bootstrap(conf);
int size = dmm.getAvailableDevices(resourceName);
Assert.assertEquals(3, size);
// A container c1 requests 1 device
Container c1 = mockContainerWithDeviceRequest(0,
resourceName,
1, false);
// preStart triggers the device allocation for c1
adapter.getDeviceResourceHandler().preStart(c1);
// The customized scheduler must have been used exactly once
verify(spyPlugin, times(1)).allocateDevices(
any(TreeSet.class), anyInt());
// One of the three devices is now used, the allowed set is unchanged
Assert.assertEquals(2,
dmm.getAvailableDevices(resourceName));
Assert.assertEquals(1,
dmm.getAllUsedDevices().get(resourceName).size());
Assert.assertEquals(3,
dmm.getAllAllowedDevices().get(resourceName).size());
}
private static Container mockContainerWithDeviceRequest(int id, private static Container mockContainerWithDeviceRequest(int id,
String resourceName, String resourceName,
int numDeviceRequest, int numDeviceRequest,
@ -469,7 +522,7 @@ public class TestDevicePluginAdapter {
.newInstance(ApplicationId.newInstance(1234L, 1), 1), id); .newInstance(ApplicationId.newInstance(1234L, 1), 1), id);
} }
private class MyPlugin implements DevicePlugin { private class MyPlugin implements DevicePlugin, DevicePluginScheduler {
private final static String RESOURCE_NAME = "cmpA.com/hdwA"; private final static String RESOURCE_NAME = "cmpA.com/hdwA";
@Override @Override
public DeviceRegisterRequest getRegisterRequestInfo() { public DeviceRegisterRequest getRegisterRequestInfo() {
@ -518,6 +571,21 @@ public class TestDevicePluginAdapter {
public void onDevicesReleased(Set<Device> releasedDevices) { public void onDevicesReleased(Set<Device> releasedDevices) {
} }
/**
 * Simple scheduler for the test plugin: collects devices in iteration
 * order until {@code count} devices have been collected.
 */
@Override
public Set<Device> allocateDevices(Set<Device> availableDevices,
    int count) {
  Set<Device> chosen = new TreeSet<>();
  // Track how many more devices are wanted; stop when it hits zero
  int stillNeeded = count;
  for (Device candidate : availableDevices) {
    chosen.add(candidate);
    stillNeeded--;
    if (stillNeeded == 0) {
      return chosen;
    }
  }
  return chosen;
}
} // MyPlugin } // MyPlugin
} }