YARN-1954. Added waitFor to AMRMClient(Async). Contributed by Tsuyoshi Ozawa.

git-svn-id: https://svn.apache.org/repos/asf/hadoop/common/trunk@1617002 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Zhijie Shen 2014-08-09 18:44:51 +00:00
parent a7643f4de7
commit ee3825e278
5 changed files with 237 additions and 3 deletions

View File

@ -103,6 +103,8 @@ Release 2.6.0 - UNRELEASED
YARN-2026. Fair scheduler: Consider only active queues for computing fairshare.
(Ashwin Shankar via kasha)
YARN-1954. Added waitFor to AMRMClient(Async). (Tsuyoshi Ozawa via zjshen)
OPTIMIZATIONS
BUG FIXES

View File

@ -22,6 +22,8 @@ import java.io.IOException;
import java.util.Collection;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceAudience.Private;
import org.apache.hadoop.classification.InterfaceAudience.Public;
@ -37,12 +39,14 @@ import org.apache.hadoop.yarn.client.api.impl.AMRMClientImpl;
import org.apache.hadoop.yarn.exceptions.YarnException;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
@InterfaceAudience.Public
@InterfaceStability.Stable
public abstract class AMRMClient<T extends AMRMClient.ContainerRequest> extends
AbstractService {
private static final Log LOG = LogFactory.getLog(AMRMClient.class);
/**
* Create a new instance of AMRMClient.
@ -336,4 +340,63 @@ public abstract class AMRMClient<T extends AMRMClient.ContainerRequest> extends
return nmTokenCache;
}
/**
* Wait for <code>check</code> to return true for each 1000 ms.
* See also {@link #waitFor(com.google.common.base.Supplier, int)}
* and {@link #waitFor(com.google.common.base.Supplier, int, int)}
* @param check
*/
public void waitFor(Supplier<Boolean> check) throws InterruptedException {
waitFor(check, 1000);
}
/**
* Wait for <code>check</code> to return true for each
* <code>checkEveryMillis</code> ms.
* See also {@link #waitFor(com.google.common.base.Supplier, int, int)}
* @param check user defined checker
* @param checkEveryMillis interval to call <code>check</code>
*/
public void waitFor(Supplier<Boolean> check, int checkEveryMillis)
throws InterruptedException {
waitFor(check, checkEveryMillis, 1);
}
/**
* Wait for <code>check</code> to return true for each
* <code>checkEveryMillis</code> ms. In the main loop, this method will log
* the message "waiting in main loop" for each <code>logInterval</code> times
* iteration to confirm the thread is alive.
* @param check user defined checker
* @param checkEveryMillis interval to call <code>check</code>
* @param logInterval interval to log for each
*/
public void waitFor(Supplier<Boolean> check, int checkEveryMillis,
int logInterval) throws InterruptedException {
Preconditions.checkNotNull(check, "check should not be null");
Preconditions.checkArgument(checkEveryMillis >= 0,
"checkEveryMillis should be positive value");
Preconditions.checkArgument(logInterval >= 0,
"logInterval should be positive value");
int loggingCounter = logInterval;
do {
if (LOG.isDebugEnabled()) {
LOG.debug("Check the condition for main loop.");
}
boolean result = check.get();
if (result) {
LOG.info("Exits the main loop.");
return;
}
if (--loggingCounter <= 0) {
LOG.info("Waiting in main loop.");
loggingCounter = logInterval;
}
Thread.sleep(checkEveryMillis);
} while (true);
}
}

View File

@ -18,11 +18,15 @@
package org.apache.hadoop.yarn.client.api.async;
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience.Private;
import org.apache.hadoop.classification.InterfaceAudience.Public;
import org.apache.hadoop.classification.InterfaceStability.Stable;
@ -90,6 +94,7 @@ import com.google.common.annotations.VisibleForTesting;
@Stable
public abstract class AMRMClientAsync<T extends ContainerRequest>
extends AbstractService {
private static final Log LOG = LogFactory.getLog(AMRMClientAsync.class);
protected final AMRMClient<T> client;
protected final CallbackHandler handler;
@ -189,6 +194,65 @@ extends AbstractService {
*/
public abstract int getClusterNodeCount();
/**
* Wait for <code>check</code> to return true for each 1000 ms.
* See also {@link #waitFor(com.google.common.base.Supplier, int)}
* and {@link #waitFor(com.google.common.base.Supplier, int, int)}
* @param check
*/
public void waitFor(Supplier<Boolean> check) throws InterruptedException {
waitFor(check, 1000);
}
/**
* Wait for <code>check</code> to return true for each
* <code>checkEveryMillis</code> ms.
* See also {@link #waitFor(com.google.common.base.Supplier, int, int)}
* @param check user defined checker
* @param checkEveryMillis interval to call <code>check</code>
*/
public void waitFor(Supplier<Boolean> check, int checkEveryMillis)
throws InterruptedException {
waitFor(check, checkEveryMillis, 1);
};
/**
* Wait for <code>check</code> to return true for each
* <code>checkEveryMillis</code> ms. In the main loop, this method will log
* the message "waiting in main loop" for each <code>logInterval</code> times
* iteration to confirm the thread is alive.
* @param check user defined checker
* @param checkEveryMillis interval to call <code>check</code>
* @param logInterval interval to log for each
*/
public void waitFor(Supplier<Boolean> check, int checkEveryMillis,
int logInterval) throws InterruptedException {
Preconditions.checkNotNull(check, "check should not be null");
Preconditions.checkArgument(checkEveryMillis >= 0,
"checkEveryMillis should be positive value");
Preconditions.checkArgument(logInterval >= 0,
"logInterval should be positive value");
int loggingCounter = logInterval;
do {
if (LOG.isDebugEnabled()) {
LOG.debug("Check the condition for main loop.");
}
boolean result = check.get();
if (result) {
LOG.info("Exits the main loop.");
return;
}
if (--loggingCounter <= 0) {
LOG.info("Waiting in main loop.");
loggingCounter = logInterval;
}
Thread.sleep(checkEveryMillis);
} while (true);
}
public interface CallbackHandler {
/**

View File

@ -18,6 +18,7 @@
package org.apache.hadoop.yarn.client.api.async.impl;
import com.google.common.base.Supplier;
import static org.mockito.Matchers.anyFloat;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Matchers.anyString;
@ -228,6 +229,41 @@ public class TestAMRMClientAsync {
asyncClient.stop();
}
@Test (timeout = 10000)
public void testAMRMClientAsyncShutDownWithWaitFor() throws Exception {
Configuration conf = new Configuration();
final TestCallbackHandler callbackHandler = new TestCallbackHandler();
@SuppressWarnings("unchecked")
AMRMClient<ContainerRequest> client = mock(AMRMClientImpl.class);
final AllocateResponse shutDownResponse = createAllocateResponse(
new ArrayList<ContainerStatus>(), new ArrayList<Container>(), null);
shutDownResponse.setAMCommand(AMCommand.AM_SHUTDOWN);
when(client.allocate(anyFloat())).thenReturn(shutDownResponse);
AMRMClientAsync<ContainerRequest> asyncClient =
AMRMClientAsync.createAMRMClientAsync(client, 10, callbackHandler);
asyncClient.init(conf);
asyncClient.start();
Supplier<Boolean> checker = new Supplier<Boolean>() {
@Override
public Boolean get() {
return callbackHandler.reboot;
}
};
asyncClient.registerApplicationMaster("localhost", 1234, null);
asyncClient.waitFor(checker);
asyncClient.stop();
// stopping should have joined all threads and completed all callbacks
Assert.assertTrue(callbackHandler.callbackCount == 0);
verify(client, times(1)).allocate(anyFloat());
asyncClient.stop();
}
@Test (timeout = 5000)
public void testCallAMRMClientAsyncStopFromCallbackHandler()
throws YarnException, IOException, InterruptedException {
@ -262,6 +298,40 @@ public class TestAMRMClientAsync {
}
}
@Test (timeout = 5000)
public void testCallAMRMClientAsyncStopFromCallbackHandlerWithWaitFor()
throws YarnException, IOException, InterruptedException {
Configuration conf = new Configuration();
final TestCallbackHandler2 callbackHandler = new TestCallbackHandler2();
@SuppressWarnings("unchecked")
AMRMClient<ContainerRequest> client = mock(AMRMClientImpl.class);
List<ContainerStatus> completed = Arrays.asList(
ContainerStatus.newInstance(newContainerId(0, 0, 0, 0),
ContainerState.COMPLETE, "", 0));
final AllocateResponse response = createAllocateResponse(completed,
new ArrayList<Container>(), null);
when(client.allocate(anyFloat())).thenReturn(response);
AMRMClientAsync<ContainerRequest> asyncClient =
AMRMClientAsync.createAMRMClientAsync(client, 20, callbackHandler);
callbackHandler.asynClient = asyncClient;
asyncClient.init(conf);
asyncClient.start();
Supplier<Boolean> checker = new Supplier<Boolean>() {
@Override
public Boolean get() {
return callbackHandler.notify;
}
};
asyncClient.registerApplicationMaster("localhost", 1234, null);
asyncClient.waitFor(checker);
Assert.assertTrue(checker.get());
}
void runCallBackThrowOutException(TestCallbackHandler2 callbackHandler) throws
InterruptedException, YarnException, IOException {
Configuration conf = new Configuration();
@ -342,7 +412,7 @@ public class TestAMRMClientAsync {
private volatile List<ContainerStatus> completedContainers;
private volatile List<Container> allocatedContainers;
Exception savedException = null;
boolean reboot = false;
volatile boolean reboot = false;
Object notifier = new Object();
int callbackCount = 0;
@ -432,7 +502,7 @@ public class TestAMRMClientAsync {
@SuppressWarnings("rawtypes")
AMRMClientAsync asynClient;
boolean stop = true;
boolean notify = false;
volatile boolean notify = false;
boolean throwOutException = false;
@Override

View File

@ -18,6 +18,7 @@
package org.apache.hadoop.yarn.client.api.impl;
import com.google.common.base.Supplier;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@ -815,6 +816,40 @@ public class TestAMRMClient {
assertEquals(0, amClient.release.size());
}
class CountDownSupplier implements Supplier<Boolean> {
int counter = 0;
@Override
public Boolean get() {
counter++;
if (counter >= 3) {
return true;
} else {
return false;
}
}
};
@Test
public void testWaitFor() throws InterruptedException {
AMRMClientImpl<ContainerRequest> amClient = null;
CountDownSupplier countDownChecker = new CountDownSupplier();
try {
// start am rm client
amClient =
(AMRMClientImpl<ContainerRequest>) AMRMClient
.<ContainerRequest> createAMRMClient();
amClient.init(new YarnConfiguration());
amClient.start();
amClient.waitFor(countDownChecker, 1000);
assertEquals(3, countDownChecker.counter);
} finally {
if (amClient != null) {
amClient.stop();
}
}
}
private void sleep(int sleepTime) {
try {
Thread.sleep(sleepTime);