diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/java/org/apache/hadoop/yarn/nodelabels/CommonNodeLabelsManager.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/java/org/apache/hadoop/yarn/nodelabels/CommonNodeLabelsManager.java index 4fd4bd6bd72..8fd0315fa85 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/java/org/apache/hadoop/yarn/nodelabels/CommonNodeLabelsManager.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/main/java/org/apache/hadoop/yarn/nodelabels/CommonNodeLabelsManager.java @@ -559,6 +559,50 @@ public class CommonNodeLabelsManager extends AbstractService { addNodeToLabels(node, newLabels); } + private void addLabelsToNodeInHost(NodeId node, Set labels) + throws IOException { + Host host = nodeCollections.get(node.getHost()); + if (null == host) { + throw new IOException("Cannot add labels to a host that " + + "does not exist. Create the host before adding labels to it."); + } + Node nm = host.nms.get(node); + if (nm != null) { + Node newNm = nm.copy(); + if (newNm.labels == null) { + newNm.labels = + Collections.newSetFromMap(new ConcurrentHashMap()); + } + newNm.labels.addAll(labels); + host.nms.put(node, newNm); + } + } + + protected void removeLabelsFromNodeInHost(NodeId node, Set labels) + throws IOException { + Host host = nodeCollections.get(node.getHost()); + if (null == host) { + throw new IOException("Cannot remove labels from a host that " + + "does not exist. Create the host before adding labels to it."); + } + Node nm = host.nms.get(node); + if (nm != null) { + if (nm.labels == null) { + nm.labels = new HashSet(); + } else { + nm.labels.removeAll(labels); + } + } + } + + private void replaceLabelsForNode(NodeId node, Set oldLabels, + Set newLabels) throws IOException { + if(oldLabels != null) { + removeLabelsFromNodeInHost(node, oldLabels); + } + addLabelsToNodeInHost(node, newLabels); + } + @SuppressWarnings("unchecked") protected void internalUpdateLabelsOnNodes( Map> nodeToLabels, NodeLabelUpdateOperation op) @@ -597,10 +641,12 @@ public class CommonNodeLabelsManager extends AbstractService { break; case REPLACE: replaceNodeForLabels(nodeId, host.labels, labels); + replaceLabelsForNode(nodeId, host.labels, labels); host.labels.clear(); host.labels.addAll(labels); for (Node node : host.nms.values()) { replaceNodeForLabels(node.nodeId, node.labels, labels); + replaceLabelsForNode(node.nodeId, node.labels, labels); node.labels = null; } break; @@ -625,6 +671,7 @@ public class CommonNodeLabelsManager extends AbstractService { case REPLACE: oldLabels = getLabelsByNode(nodeId); replaceNodeForLabels(nodeId, oldLabels, labels); + replaceLabelsForNode(nodeId, oldLabels, labels); if (nm.labels == null) { nm.labels = new HashSet(); } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/test/java/org/apache/hadoop/yarn/nodelabels/TestCommonNodeLabelsManager.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/test/java/org/apache/hadoop/yarn/nodelabels/TestCommonNodeLabelsManager.java index e86cca4e245..d9f9389866e 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/test/java/org/apache/hadoop/yarn/nodelabels/TestCommonNodeLabelsManager.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-common/src/test/java/org/apache/hadoop/yarn/nodelabels/TestCommonNodeLabelsManager.java @@ -616,4 +616,22 @@ public class TestCommonNodeLabelsManager extends NodeLabelTestBase { toNodeId("n1"), toSet(NodeLabel.newInstance("p2", true)), toNodeId("n2"), toSet(NodeLabel.newInstance("p3", false)))); } + + @Test(timeout = 5000) + public void testRemoveNodeLabelsInfo() throws IOException { + mgr.addToCluserNodeLabels(Arrays.asList(NodeLabel.newInstance("p1", true))); + mgr.addToCluserNodeLabels(Arrays.asList(NodeLabel.newInstance("p2", true))); + mgr.addLabelsToNode(ImmutableMap.of(toNodeId("n1:1"), toSet("p1"))); + mgr.replaceLabelsOnNode(ImmutableMap.of(toNodeId("n1"), toSet("p2"))); + + Map> labelsToNodes = mgr.getLabelsToNodes(); + assertLabelsToNodesEquals( + labelsToNodes, + ImmutableMap.of( + "p2", toSet(toNodeId("n1:1"), toNodeId("n1:0")))); + + mgr.replaceLabelsOnNode(ImmutableMap.of(toNodeId("n1"), new HashSet())); + Map> labelsToNodes2 = mgr.getLabelsToNodes(); + Assert.assertEquals(labelsToNodes2.get("p2"), null); + } }