From 97b7d39a5f5a18c0ebca00d1c74d4bc0314024e2 Mon Sep 17 00:00:00 2001 From: Carlo Curino Date: Thu, 12 Oct 2017 10:38:58 -0700 Subject: [PATCH] YARN-7317. Fix overallocation resulted from ceiling in LocalityMulticastAMRMProxyPolicy. (contributed by Botong Huang via curino) (cherry picked from commit 13fcfb3d46ee7a0d606b4bb221d1cd66ef2a5a7c) --- .../policies/FederationPolicyUtils.java | 41 ++++++- .../LocalityMulticastAMRMProxyPolicy.java | 101 +++++++++++++--- .../router/WeightedRandomRouterPolicy.java | 33 ++---- .../policies/TestFederationPolicyUtils.java | 58 +++++++++ .../TestLocalityMulticastAMRMProxyPolicy.java | 110 ++++++++++++++---- 5 files changed, 278 insertions(+), 65 deletions(-) create mode 100644 hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/federation/policies/TestFederationPolicyUtils.java diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/policies/FederationPolicyUtils.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/policies/FederationPolicyUtils.java index 7716a6f7174..aaa2c43c6ae 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/policies/FederationPolicyUtils.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/policies/FederationPolicyUtils.java @@ -19,7 +19,9 @@ package org.apache.hadoop.yarn.server.federation.policies; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.List; +import java.util.Random; import org.apache.hadoop.classification.InterfaceAudience.Private; import org.apache.hadoop.conf.Configuration; @@ -46,6 +48,8 @@ public final class FederationPolicyUtils { public static final String NO_ACTIVE_SUBCLUSTER_AVAILABLE = "No active SubCluster available to submit the request."; + private static final Random RAND = new Random(System.currentTimeMillis()); + /** Disable constructor. */ private FederationPolicyUtils() { } @@ -200,4 +204,39 @@ public final class FederationPolicyUtils { FederationPolicyUtils.NO_ACTIVE_SUBCLUSTER_AVAILABLE); } -} \ No newline at end of file + /** + * Select a random bin according to the weight array for the bins. Only bins + * with positive weights will be considered. If no positive weight found, + * return -1. + * + * @param weights the weight array + * @return the index of the sample in the array + */ + public static int getWeightedRandom(ArrayList weights) { + int i; + float totalWeight = 0; + for (i = 0; i < weights.size(); i++) { + if (weights.get(i) > 0) { + totalWeight += weights.get(i); + } + } + if (totalWeight == 0) { + return -1; + } + float samplePoint = RAND.nextFloat() * totalWeight; + int lastIndex = 0; + for (i = 0; i < weights.size(); i++) { + if (weights.get(i) > 0) { + if (samplePoint <= weights.get(i)) { + return i; + } else { + lastIndex = i; + samplePoint -= weights.get(i); + } + } + } + // This can only happen if samplePoint is very close to totoalWeight and + // float rounding kicks in during subtractions + return lastIndex; + } +} diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/policies/amrmproxy/LocalityMulticastAMRMProxyPolicy.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/policies/amrmproxy/LocalityMulticastAMRMProxyPolicy.java index 454962f63f6..da30d98ff58 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/policies/amrmproxy/LocalityMulticastAMRMProxyPolicy.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/policies/amrmproxy/LocalityMulticastAMRMProxyPolicy.java @@ -34,7 +34,9 @@ import org.apache.hadoop.yarn.api.records.ResourceRequest; import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.exceptions.YarnRuntimeException; import org.apache.hadoop.yarn.server.federation.policies.FederationPolicyInitializationContext; +import org.apache.hadoop.yarn.server.federation.policies.FederationPolicyUtils; import org.apache.hadoop.yarn.server.federation.policies.dao.WeightedPolicyInfo; +import org.apache.hadoop.yarn.server.federation.policies.exceptions.FederationPolicyException; import org.apache.hadoop.yarn.server.federation.policies.exceptions.FederationPolicyInitializationException; import org.apache.hadoop.yarn.server.federation.policies.exceptions.NoActiveSubclustersException; import org.apache.hadoop.yarn.server.federation.resolver.SubClusterResolver; @@ -45,6 +47,7 @@ import org.apache.hadoop.yarn.server.federation.utils.FederationStateStoreFacade import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; /** @@ -314,25 +317,33 @@ public class LocalityMulticastAMRMProxyPolicy extends AbstractAMRMProxyPolicy { */ private void splitIndividualAny(ResourceRequest originalResourceRequest, Set targetSubclusters, - AllocationBookkeeper allocationBookkeeper) { + AllocationBookkeeper allocationBookkeeper) throws YarnException { long allocationId = originalResourceRequest.getAllocationRequestId(); + int numContainer = originalResourceRequest.getNumContainers(); - for (SubClusterId targetId : targetSubclusters) { - float numContainer = originalResourceRequest.getNumContainers(); - - // If the ANY request has 0 containers to begin with we must forward it to - // any RM we have previously contacted (this might be the user way - // to cancel a previous request). - if (numContainer == 0 && headroom.containsKey(targetId)) { - allocationBookkeeper.addAnyRR(targetId, originalResourceRequest); + // If the ANY request has 0 containers to begin with we must forward it to + // any RM we have previously contacted (this might be the user way + // to cancel a previous request). + if (numContainer == 0) { + for (SubClusterId targetId : targetSubclusters) { + if (headroom.containsKey(targetId)) { + allocationBookkeeper.addAnyRR(targetId, originalResourceRequest); + } } + return; + } + // List preserves iteration order + List targetSCs = new ArrayList<>(targetSubclusters); + + // Compute the distribution weights + ArrayList weightsList = new ArrayList<>(); + for (SubClusterId targetId : targetSCs) { // If ANY is associated with localized asks, split based on their ratio if (allocationBookkeeper.getSubClustersForId(allocationId) != null) { - float localityBasedWeight = getLocalityBasedWeighting(allocationId, - targetId, allocationBookkeeper); - numContainer = numContainer * localityBasedWeight; + weightsList.add(getLocalityBasedWeighting(allocationId, targetId, + allocationBookkeeper)); } else { // split ANY based on load and policy configuration float headroomWeighting = @@ -340,12 +351,18 @@ public class LocalityMulticastAMRMProxyPolicy extends AbstractAMRMProxyPolicy { float policyWeighting = getPolicyConfigWeighting(targetId, allocationBookkeeper); // hrAlpha controls how much headroom influencing decision - numContainer = numContainer - * (hrAlpha * headroomWeighting + (1 - hrAlpha) * policyWeighting); + weightsList + .add(hrAlpha * headroomWeighting + (1 - hrAlpha) * policyWeighting); } + } + // Compute the integer container counts for each sub-cluster + ArrayList containerNums = + computeIntegerAssignment(numContainer, weightsList); + int i = 0; + for (SubClusterId targetId : targetSCs) { // if the calculated request is non-empty add it to the answer - if (numContainer > 0) { + if (containerNums.get(i) > 0) { ResourceRequest out = ResourceRequest.newInstance(originalResourceRequest.getPriority(), originalResourceRequest.getResourceName(), @@ -355,16 +372,68 @@ public class LocalityMulticastAMRMProxyPolicy extends AbstractAMRMProxyPolicy { originalResourceRequest.getNodeLabelExpression(), originalResourceRequest.getExecutionTypeRequest()); out.setAllocationRequestId(allocationId); - out.setNumContainers((int) Math.ceil(numContainer)); + out.setNumContainers(containerNums.get(i)); if (ResourceRequest.isAnyLocation(out.getResourceName())) { allocationBookkeeper.addAnyRR(targetId, out); } else { allocationBookkeeper.addRackRR(targetId, out); } } + i++; } } + /** + * Split the integer into bins according to the weights. + * + * @param totalNum total number of containers to split + * @param weightsList the weights for each subcluster + * @return the container allocation after split + * @throws YarnException if fails + */ + @VisibleForTesting + protected ArrayList computeIntegerAssignment(int totalNum, + ArrayList weightsList) throws YarnException { + int i, residue; + ArrayList ret = new ArrayList<>(); + float totalWeight = 0, totalNumFloat = totalNum; + + if (weightsList.size() == 0) { + return ret; + } + for (i = 0; i < weightsList.size(); i++) { + ret.add(0); + if (weightsList.get(i) > 0) { + totalWeight += weightsList.get(i); + } + } + if (totalWeight == 0) { + StringBuilder sb = new StringBuilder(); + for (Float weight : weightsList) { + sb.append(weight + ", "); + } + throw new FederationPolicyException( + "No positive value found in weight array " + sb.toString()); + } + + // First pass, do flooring for all bins + residue = totalNum; + for (i = 0; i < weightsList.size(); i++) { + if (weightsList.get(i) > 0) { + int base = (int) (totalNumFloat * weightsList.get(i) / totalWeight); + ret.set(i, ret.get(i) + base); + residue -= base; + } + } + + // By now residue < weights.length, assign one a time + for (i = 0; i < residue; i++) { + int index = FederationPolicyUtils.getWeightedRandom(weightsList); + ret.set(index, ret.get(index) + 1); + } + return ret; + } + /** * Compute the weight to assign to a subcluster based on how many local * requests a subcluster is target of. diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/policies/router/WeightedRandomRouterPolicy.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/policies/router/WeightedRandomRouterPolicy.java index aec75760414..b1434104836 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/policies/router/WeightedRandomRouterPolicy.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/main/java/org/apache/hadoop/yarn/server/federation/policies/router/WeightedRandomRouterPolicy.java @@ -21,16 +21,14 @@ package org.apache.hadoop.yarn.server.federation.policies.router; import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Random; import org.apache.hadoop.yarn.api.records.ApplicationSubmissionContext; import org.apache.hadoop.yarn.exceptions.YarnException; import org.apache.hadoop.yarn.server.federation.policies.FederationPolicyUtils; +import org.apache.hadoop.yarn.server.federation.policies.exceptions.FederationPolicyException; import org.apache.hadoop.yarn.server.federation.store.records.SubClusterId; import org.apache.hadoop.yarn.server.federation.store.records.SubClusterIdInfo; import org.apache.hadoop.yarn.server.federation.store.records.SubClusterInfo; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * This policy implements a weighted random sample among currently active @@ -38,10 +36,6 @@ import org.slf4j.LoggerFactory; */ public class WeightedRandomRouterPolicy extends AbstractRouterPolicy { - private static final Logger LOG = - LoggerFactory.getLogger(WeightedRandomRouterPolicy.class); - private Random rand = new Random(System.currentTimeMillis()); - @Override public SubClusterId getHomeSubcluster( ApplicationSubmissionContext appSubmissionContext, @@ -63,32 +57,25 @@ public class WeightedRandomRouterPolicy extends AbstractRouterPolicy { Map weights = getPolicyInfo().getRouterPolicyWeights(); - float totActiveWeight = 0; + ArrayList weightList = new ArrayList<>(); + ArrayList scIdList = new ArrayList<>(); for (Map.Entry entry : weights.entrySet()) { if (blacklist != null && blacklist.contains(entry.getKey().toId())) { continue; } if (entry.getKey() != null && activeSubclusters.containsKey(entry.getKey().toId())) { - totActiveWeight += entry.getValue(); + weightList.add(entry.getValue()); + scIdList.add(entry.getKey().toId()); } } - float lookupValue = rand.nextFloat() * totActiveWeight; - for (SubClusterId id : activeSubclusters.keySet()) { - if (blacklist != null && blacklist.contains(id)) { - continue; - } - SubClusterIdInfo idInfo = new SubClusterIdInfo(id); - if (weights.containsKey(idInfo)) { - lookupValue -= weights.get(idInfo); - } - if (lookupValue <= 0) { - return id; - } + int pickedIndex = FederationPolicyUtils.getWeightedRandom(weightList); + if (pickedIndex == -1) { + throw new FederationPolicyException( + "No positive weight found on active subclusters"); } - // should never happen - return null; + return scIdList.get(pickedIndex); } } diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/federation/policies/TestFederationPolicyUtils.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/federation/policies/TestFederationPolicyUtils.java new file mode 100644 index 00000000000..d9609788d22 --- /dev/null +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/federation/policies/TestFederationPolicyUtils.java @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package org.apache.hadoop.yarn.server.federation.policies; + +import java.util.ArrayList; + +import org.junit.Assert; +import org.junit.Test; + +/** + * Unit test for {@link FederationPolicyUtils}. + */ +public class TestFederationPolicyUtils { + + @Test + public void testGetWeightedRandom() { + int i; + float[] weights = + new float[] {0, 0.1f, 0.2f, 0.2f, -0.1f, 0.1f, 0.2f, 0.1f, 0.1f}; + float[] expectedWeights = + new float[] {0, 0.1f, 0.2f, 0.2f, 0, 0.1f, 0.2f, 0.1f, 0.1f}; + int[] result = new int[weights.length]; + + ArrayList weightsList = new ArrayList<>(); + for (float weight : weights) { + weightsList.add(weight); + } + + int n = 10000000; + for (i = 0; i < n; i++) { + int sample = FederationPolicyUtils.getWeightedRandom(weightsList); + result[sample]++; + } + for (i = 0; i < weights.length; i++) { + double actualWeight = (float) result[i] / n; + System.out.println(i + " " + actualWeight); + Assert.assertTrue( + "Index " + i + " Actual weight: " + actualWeight + + " expected weight: " + expectedWeights[i], + Math.abs(actualWeight - expectedWeights[i]) < 0.01); + } + } +} diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/federation/policies/amrmproxy/TestLocalityMulticastAMRMProxyPolicy.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/federation/policies/amrmproxy/TestLocalityMulticastAMRMProxyPolicy.java index 6e3a2f14efe..46a60115017 100644 --- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/federation/policies/amrmproxy/TestLocalityMulticastAMRMProxyPolicy.java +++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-common/src/test/java/org/apache/hadoop/yarn/server/federation/policies/amrmproxy/TestLocalityMulticastAMRMProxyPolicy.java @@ -157,18 +157,20 @@ public class TestLocalityMulticastAMRMProxyPolicy validateSplit(response, resourceRequests); - // based on headroom, we expect 75 containers to got to subcluster0, - // as it advertise lots of headroom (100), no containers for sublcuster1 - // as it advertise zero headroom, 1 to subcluster 2 (as it advertise little - // headroom (1), and 25 to subcluster5 which has unknown headroom, and so - // it gets 1/4th of the load - checkExpectedAllocation(response, "subcluster0", 1, 75); + /* + * based on headroom, we expect 75 containers to got to subcluster0 (60) and + * subcluster2 (15) according to the advertised headroom (40 and 10), no + * containers for sublcuster1 as it advertise zero headroom, and 25 to + * subcluster5 which has unknown headroom, and so it gets 1/4th of the load + */ + checkExpectedAllocation(response, "subcluster0", 1, 60); checkExpectedAllocation(response, "subcluster1", 1, -1); - checkExpectedAllocation(response, "subcluster2", 1, 1); + checkExpectedAllocation(response, "subcluster2", 1, 15); checkExpectedAllocation(response, "subcluster5", 1, 25); + checkTotalContainerAllocation(response, 100); // notify a change in headroom and try again - AllocateResponse ar = getAllocateResponseWithTargetHeadroom(100); + AllocateResponse ar = getAllocateResponseWithTargetHeadroom(40); ((FederationAMRMProxyPolicy) getPolicy()) .notifyOfResponse(SubClusterId.newInstance("subcluster2"), ar); response = ((FederationAMRMProxyPolicy) getPolicy()) @@ -178,14 +180,16 @@ public class TestLocalityMulticastAMRMProxyPolicy prettyPrintRequests(response); validateSplit(response, resourceRequests); - // we simulated a change in headroom for subcluster2, which will now - // have the same headroom of subcluster0 and so it splits the requests - // note that the total is still less or equal to (userAsk + numSubClusters) - checkExpectedAllocation(response, "subcluster0", 1, 38); + /* + * we simulated a change in headroom for subcluster2, which will now have + * the same headroom of subcluster0, so each 37.5, note that the odd one + * will be assigned to either one of the two subclusters + */ + checkExpectedAllocation(response, "subcluster0", 1, 37); checkExpectedAllocation(response, "subcluster1", 1, -1); - checkExpectedAllocation(response, "subcluster2", 1, 38); + checkExpectedAllocation(response, "subcluster2", 1, 37); checkExpectedAllocation(response, "subcluster5", 1, 25); - + checkTotalContainerAllocation(response, 100); } @Test(timeout = 5000) @@ -250,6 +254,7 @@ public class TestLocalityMulticastAMRMProxyPolicy checkExpectedAllocation(response, "subcluster3", -1, -1); checkExpectedAllocation(response, "subcluster4", -1, -1); checkExpectedAllocation(response, "subcluster5", -1, -1); + checkTotalContainerAllocation(response, 0); } @Test @@ -276,19 +281,19 @@ public class TestLocalityMulticastAMRMProxyPolicy validateSplit(response, resourceRequests); // in this case the headroom allocates 50 containers, while weights allocate - // the rest. due to weights we have 12.5 (round to 13) containers for each + // the rest. due to weights we have 12.5 containers for each // sublcuster, the rest is due to headroom. - checkExpectedAllocation(response, "subcluster0", 1, 50); - checkExpectedAllocation(response, "subcluster1", 1, 13); - checkExpectedAllocation(response, "subcluster2", 1, 13); + checkExpectedAllocation(response, "subcluster0", 1, 42); // 30 + 12.5 + checkExpectedAllocation(response, "subcluster1", 1, 12); // 0 + 12.5 + checkExpectedAllocation(response, "subcluster2", 1, 20); // 7.5 + 12.5 checkExpectedAllocation(response, "subcluster3", -1, -1); checkExpectedAllocation(response, "subcluster4", -1, -1); - checkExpectedAllocation(response, "subcluster5", 1, 25); - + checkExpectedAllocation(response, "subcluster5", 1, 25); // 12.5 + 12.5 + checkTotalContainerAllocation(response, 100); } private void prepPolicyWithHeadroom() throws YarnException { - AllocateResponse ar = getAllocateResponseWithTargetHeadroom(100); + AllocateResponse ar = getAllocateResponseWithTargetHeadroom(40); ((FederationAMRMProxyPolicy) getPolicy()) .notifyOfResponse(SubClusterId.newInstance("subcluster0"), ar); @@ -296,7 +301,7 @@ public class TestLocalityMulticastAMRMProxyPolicy ((FederationAMRMProxyPolicy) getPolicy()) .notifyOfResponse(SubClusterId.newInstance("subcluster1"), ar); - ar = getAllocateResponseWithTargetHeadroom(1); + ar = getAllocateResponseWithTargetHeadroom(10); ((FederationAMRMProxyPolicy) getPolicy()) .notifyOfResponse(SubClusterId.newInstance("subcluster2"), ar); } @@ -363,6 +368,9 @@ public class TestLocalityMulticastAMRMProxyPolicy // subcluster5 should get only part of the request-id 2 broadcast checkExpectedAllocation(response, "subcluster5", 1, 20); + // Check the total number of container asks in all RR + checkTotalContainerAllocation(response, 130); + // check that the allocations that show up are what expected for (ResourceRequest rr : response.get(getHomeSubCluster())) { Assert.assertTrue( @@ -401,8 +409,8 @@ public class TestLocalityMulticastAMRMProxyPolicy // response should be null private void checkExpectedAllocation( Map> response, String subCluster, - long totResourceRequests, long totContainers) { - if (totContainers == -1) { + long totResourceRequests, long minimumTotalContainers) { + if (minimumTotalContainers == -1) { Assert.assertNull(response.get(SubClusterId.newInstance(subCluster))); } else { SubClusterId sc = SubClusterId.newInstance(subCluster); @@ -412,10 +420,25 @@ public class TestLocalityMulticastAMRMProxyPolicy for (ResourceRequest rr : response.get(sc)) { actualContCount += rr.getNumContainers(); } - Assert.assertEquals(totContainers, actualContCount); + Assert.assertTrue( + "Actual count " + actualContCount + " should be at least " + + minimumTotalContainers, + minimumTotalContainers <= actualContCount); } } + private void checkTotalContainerAllocation( + Map> response, long totalContainers) { + long actualContCount = 0; + for (Map.Entry> entry : response + .entrySet()) { + for (ResourceRequest rr : entry.getValue()) { + actualContCount += rr.getNumContainers(); + } + } + Assert.assertEquals(totalContainers, actualContCount); + } + private void validateSplit(Map> split, List original) throws YarnException { @@ -599,4 +622,41 @@ public class TestLocalityMulticastAMRMProxyPolicy return out; } + + public String printList(ArrayList list) { + StringBuilder sb = new StringBuilder(); + for (Integer entry : list) { + sb.append(entry + ", "); + } + return sb.toString(); + } + + @Test + public void testIntegerAssignment() throws YarnException { + float[] weights = + new float[] {0, 0.1f, 0.2f, 0.2f, -0.1f, 0.1f, 0.2f, 0.1f, 0.1f}; + int[] expectedMin = new int[] {0, 1, 3, 3, 0, 1, 3, 1, 1}; + ArrayList weightsList = new ArrayList<>(); + for (float weight : weights) { + weightsList.add(weight); + } + + LocalityMulticastAMRMProxyPolicy policy = + (LocalityMulticastAMRMProxyPolicy) getPolicy(); + for (int i = 0; i < 500000; i++) { + ArrayList allocations = + policy.computeIntegerAssignment(19, weightsList); + int sum = 0; + for (int j = 0; j < weights.length; j++) { + sum += allocations.get(j); + if (allocations.get(j) < expectedMin[j]) { + Assert.fail(allocations.get(j) + " at index " + j + + " should be at least " + expectedMin[j] + ". Allocation array: " + + printList(allocations)); + } + } + Assert.assertEquals( + "Expect sum to be 19 in array: " + printList(allocations), 19, sum); + } + } } \ No newline at end of file