From dfaac5643386cef7196c14b6dac2e39caf5c8935 Mon Sep 17 00:00:00 2001
From: Arun Suresh <asuresh@apache.org>
Date: Fri, 16 Sep 2016 16:53:18 -0700
Subject: [PATCH] YARN-5637. Changes in NodeManager to support Container
 rollback and commit. (asuresh)

(cherry picked from commit 3552c2b99dff4f21489ff284f9dcba40e897a1e5)
---
 .../ContainerManagerImpl.java                 |  68 ++++++-
 .../containermanager/container/Container.java |   4 +
 .../container/ContainerEventType.java         |   1 +
 .../container/ContainerImpl.java              | 188 +++++++++++++-----
 .../container/ContainerReInitEvent.java       |  20 +-
 .../TestContainerManagerWithLCE.java          |  42 +++-
 .../TestContainerManager.java                 | 152 ++++++++++++--
 .../nodemanager/webapp/MockContainer.java     |  10 +
 8 files changed, 401 insertions(+), 84 deletions(-)

diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/ContainerManagerImpl.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/ContainerManagerImpl.java
index ebc697f1223..9d9566e6504 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/ContainerManagerImpl.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/ContainerManagerImpl.java
@@ -163,8 +163,8 @@
 public class ContainerManagerImpl extends CompositeService implements
     ContainerManager {
 
-  private enum ReinitOp {
-    UPGRADE, COMMIT, ROLLBACK, LOCALIZE;
+  private enum ReInitOp {
+    RE_INIT, COMMIT, ROLLBACK, LOCALIZE;
   }
   /**
    * Extra duration to wait for applications to be killed on shutdown.
@@ -1455,7 +1455,7 @@ public ResourceLocalizationResponse localize(
 
     ContainerId containerId = request.getContainerId();
     Container container = preUpgradeOrLocalizeCheck(containerId,
-        ReinitOp.LOCALIZE);
+        ReInitOp.LOCALIZE);
     try {
       Map<LocalResourceVisibility, Collection<LocalResourceRequest>> req =
           container.getResourceSet().addResources(request.getLocalResources());
@@ -1471,16 +1471,31 @@ public ResourceLocalizationResponse localize(
     return ResourceLocalizationResponse.newInstance();
   }
 
-  public void upgradeContainer(ContainerId containerId,
-      ContainerLaunchContext upgradeLaunchContext) throws YarnException {
+  /**
+   * Re-initialize a container using a new Launch Context. If the
+   * retryFailureContext is not provided, the container is
+   * terminated on failure.
+   *
+   * NOTE: Auto-Commit is true by default. This also means that the rollback
+   *       context is purged as soon as the command to start the new process
+   *       is sent. (The Container moves to RUNNING state)
+   *
+   * @param containerId Container Id.
+   * @param reInitLaunchContext Target Launch Context.
+   * @param autoCommit Auto Commit flag.
+   * @throws YarnException Yarn Exception.
+   */
+  public void reInitializeContainer(ContainerId containerId,
+      ContainerLaunchContext reInitLaunchContext, boolean autoCommit)
+      throws YarnException {
     Container container = preUpgradeOrLocalizeCheck(containerId,
-        ReinitOp.UPGRADE);
+        ReInitOp.RE_INIT);
     ResourceSet resourceSet = new ResourceSet();
     try {
-      resourceSet.addResources(upgradeLaunchContext.getLocalResources());
+      resourceSet.addResources(reInitLaunchContext.getLocalResources());
       dispatcher.getEventHandler().handle(
-          new ContainerReInitEvent(containerId, upgradeLaunchContext,
-              resourceSet));
+          new ContainerReInitEvent(containerId, reInitLaunchContext,
+              resourceSet, autoCommit));
       container.setIsReInitializing(true);
     } catch (URISyntaxException e) {
       LOG.info("Error when parsing local resource URI for upgrade of" +
@@ -1489,8 +1504,41 @@ public void upgradeContainer(ContainerId containerId,
     }
   }
 
+  /**
+   * Roll back the last re-initialization via a ROLLBACK_REINIT event.
+   * @param containerId Container ID.
+   * @throws YarnException if the pre-check fails or no rollback is possible.
+   */
+  public void rollbackReInitialization(ContainerId containerId)
+      throws YarnException {
+    Container target =
+        preUpgradeOrLocalizeCheck(containerId, ReInitOp.ROLLBACK);
+    if (!target.canRollback()) {
+      throw new YarnException("Nothing to rollback to !!");
+    }
+    ContainerEvent rollbackEvent =
+        new ContainerEvent(containerId, ContainerEventType.ROLLBACK_REINIT);
+    dispatcher.getEventHandler().handle(rollbackEvent);
+  }
+
+  /**
+   * Commit the last re-initialization; the container can no longer be
+   * rolled back once this succeeds.
+   * @param containerId Container ID.
+   * @throws YarnException if the pre-check fails or nothing is pending commit.
+   */
+  public void commitReInitialization(ContainerId containerId)
+      throws YarnException {
+    Container target =
+        preUpgradeOrLocalizeCheck(containerId, ReInitOp.COMMIT);
+    if (!target.canRollback()) {
+      throw new YarnException("Nothing to Commit !!");
+    }
+    target.commitUpgrade();
+  }
+
   private Container preUpgradeOrLocalizeCheck(ContainerId containerId,
-      ReinitOp op) throws YarnException {
+      ReInitOp op) throws YarnException {
     Container container = context.getContainers().get(containerId);
     if (container == null) {
       throw new YarnException("Specified " + containerId + " does not exist!");
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/Container.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/Container.java
index 03a7a573a2e..f8a7e35393a 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/Container.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/Container.java
@@ -79,4 +79,8 @@ public interface Container extends EventHandler<ContainerEvent> {
   void setIsReInitializing(boolean isReInitializing);
 
   boolean isReInitializing();
+
+  boolean canRollback();
+
+  void commitUpgrade();
 }
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/ContainerEventType.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/ContainerEventType.java
index 0b57505d10d..afea0e6cbd0 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/ContainerEventType.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/ContainerEventType.java
@@ -26,6 +26,7 @@ public enum ContainerEventType {
   UPDATE_DIAGNOSTICS_MSG,
   CONTAINER_DONE,
   REINITIALIZE_CONTAINER,
+  ROLLBACK_REINIT,
 
   // DownloadManager
   CONTAINER_INITED,
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/ContainerImpl.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/ContainerImpl.java
index a98d3053051..3704cfd5b04 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/ContainerImpl.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/ContainerImpl.java
@@ -90,14 +90,42 @@
 
 public class ContainerImpl implements Container {
 
-  private final static class ReInitializationContext {
-    private final ResourceSet resourceSet;
+  private static final class ReInitializationContext {
     private final ContainerLaunchContext newLaunchContext;
+    private final ResourceSet newResourceSet;
+
+    // Rollback state
+    private final ContainerLaunchContext oldLaunchContext;
+    private final ResourceSet oldResourceSet;
 
     private ReInitializationContext(ContainerLaunchContext newLaunchContext,
-        ResourceSet resourceSet) {
+        ResourceSet newResourceSet,
+        ContainerLaunchContext oldLaunchContext,
+        ResourceSet oldResourceSet) {
       this.newLaunchContext = newLaunchContext;
-      this.resourceSet = resourceSet;
+      this.newResourceSet = newResourceSet;
+      this.oldLaunchContext = oldLaunchContext;
+      this.oldResourceSet = oldResourceSet;
+    }
+
+    private boolean canRollback() {
+      return (oldLaunchContext != null);
+    }
+
+    private ResourceSet mergedResourceSet() {
+      if (oldLaunchContext == null) {
+        return newResourceSet;
+      }
+      return ResourceSet.merge(oldResourceSet, newResourceSet);
+    }
+
+    private ReInitializationContext createContextForRollback() {
+      if (oldLaunchContext == null) {
+        return null;
+      } else {
+        return new ReInitializationContext(
+            oldLaunchContext, oldResourceSet, null, null);
+      }
     }
   }
 
@@ -128,7 +156,7 @@ private ReInitializationContext(ContainerLaunchContext newLaunchContext,
   private String logDir;
   private String host;
   private String ips;
-  private ReInitializationContext reInitContext;
+  private volatile ReInitializationContext reInitContext;
   private volatile boolean isReInitializing = false;
 
   /** The NM-wide configuration - not specific to this container */
@@ -186,8 +214,8 @@ public ContainerImpl(Configuration conf, Dispatcher dispatcher,
     }
 
     // Configure the Retry Context
-    this.containerRetryContext =
-        configureRetryContext(conf, launchContext, this.containerId);
+    this.containerRetryContext = configureRetryContext(
+        conf, launchContext, this.containerId);
     this.remainingRetryAttempts = this.containerRetryContext.getMaxRetries();
     stateMachine = stateMachineFactory.make(this);
     this.resourceSet = new ResourceSet();
@@ -318,12 +346,16 @@ ContainerEventType.KILL_CONTAINER, new KillTransition())
         new ExitedWithSuccessTransition(true))
     .addTransition(ContainerState.RUNNING,
         EnumSet.of(ContainerState.RELAUNCHING,
+            ContainerState.LOCALIZED,
             ContainerState.EXITED_WITH_FAILURE),
         ContainerEventType.CONTAINER_EXITED_WITH_FAILURE,
         new RetryFailureTransition())
     .addTransition(ContainerState.RUNNING, ContainerState.REINITIALIZING,
         ContainerEventType.REINITIALIZE_CONTAINER,
         new ReInitializeContainerTransition())
+    .addTransition(ContainerState.RUNNING, ContainerState.REINITIALIZING,
+        ContainerEventType.ROLLBACK_REINIT,
+        new RollbackContainerTransition())
     .addTransition(ContainerState.RUNNING, ContainerState.RUNNING,
         ContainerEventType.RESOURCE_LOCALIZED,
         new ResourceLocalizedWhileRunningTransition())
@@ -875,15 +907,15 @@ static class ReInitializeContainerTransition extends ContainerTransition {
     @SuppressWarnings("unchecked")
     @Override
     public void transition(ContainerImpl container, ContainerEvent event) {
-      container.reInitContext = createReInitContext(event);
+      container.reInitContext = createReInitContext(container, event);
       try {
         Map<LocalResourceVisibility, Collection<LocalResourceRequest>>
-            pendingResources =
-            container.reInitContext.resourceSet.getAllResourcesByVisibility();
-        if (!pendingResources.isEmpty()) {
+            resByVisibility = container.reInitContext.newResourceSet
+            .getAllResourcesByVisibility();
+        if (!resByVisibility.isEmpty()) {
           container.dispatcher.getEventHandler().handle(
               new ContainerLocalizationRequestEvent(
-                  container, pendingResources));
+                  container, resByVisibility));
         } else {
           // We are not waiting on any resources, so...
           // Kill the current container.
@@ -900,10 +932,30 @@ public void transition(ContainerImpl container, ContainerEvent event) {
     }
 
     protected ReInitializationContext createReInitContext(
-        ContainerEvent event) {
-      ContainerReInitEvent rEvent = (ContainerReInitEvent)event;
-      return new ReInitializationContext(rEvent.getReInitLaunchContext(),
-          rEvent.getResourceSet());
+        ContainerImpl container, ContainerEvent event) {
+      ContainerReInitEvent reInitEvent = (ContainerReInitEvent)event;
+      return new ReInitializationContext(
+          reInitEvent.getReInitLaunchContext(),
+          reInitEvent.getResourceSet(),
+          // If AutoCommit is turned on, then no rollback can happen...
+          // So don't need to store the previous context.
+          (reInitEvent.isAutoCommit() ? null : container.launchContext),
+          (reInitEvent.isAutoCommit() ? null : container.resourceSet));
+    }
+  }
+
+  /**
+   * Transition to start the Rollback process.
+   */
+  static class RollbackContainerTransition extends
+      ReInitializeContainerTransition {
+
+    @Override
+    protected ReInitializationContext createReInitContext(ContainerImpl
+        container, ContainerEvent event) {
+      LOG.warn("Container [" + container.getContainerId() + "]" +
+          " about to be explicitly Rolledback !!");
+      return container.reInitContext.createContextForRollback();
     }
   }
 
@@ -919,10 +971,10 @@ static class ResourceLocalizedWhileReInitTransition
     public void transition(ContainerImpl container, ContainerEvent event) {
       ContainerResourceLocalizedEvent rsrcEvent =
           (ContainerResourceLocalizedEvent) event;
-      container.reInitContext.resourceSet.resourceLocalized(
+      container.reInitContext.newResourceSet.resourceLocalized(
           rsrcEvent.getResource(), rsrcEvent.getLocation());
       // Check if all ResourceLocalization has completed
-      if (container.reInitContext.resourceSet.getPendingResources()
+      if (container.reInitContext.newResourceSet.getPendingResources()
           .isEmpty()) {
         // Kill the current container.
         container.dispatcher.getEventHandler().handle(
@@ -1019,10 +1071,13 @@ public void transition(ContainerImpl container, ContainerEvent event) {
       container.metrics.runningContainer();
       container.wasLaunched  = true;
 
-      if (container.reInitContext != null) {
+      container.setIsReInitializing(false);
+      // Check if this launch was due to a re-initialization.
+      // If autocommit == true, then wipe the re-init context. This ensures
+      // that any subsequent failures do not trigger a rollback.
+      if (container.reInitContext != null
+          && !container.reInitContext.canRollback()) {
         container.reInitContext = null;
-        // Set rollback context here..
-        container.setIsReInitializing(false);
       }
 
       if (container.recoveredAsKilled) {
@@ -1139,36 +1194,50 @@ public ContainerState transition(final ContainerImpl container,
                     + container.getContainerId(), e);
           }
         }
-        LOG.info("Relaunching Container " + container.getContainerId()
-            + ". Remaining retry attempts(after relaunch) : "
-            + container.remainingRetryAttempts
-            + ". Interval between retries is "
-            + container.containerRetryContext.getRetryInterval() + "ms");
-        container.wasLaunched  = false;
-        container.metrics.endRunningContainer();
-        if (container.containerRetryContext.getRetryInterval() == 0) {
-          container.sendRelaunchEvent();
-        } else {
-          // wait for some time, then send launch event
-          new Thread() {
-            @Override
-            public void run() {
-              try {
-                Thread.sleep(
-                    container.containerRetryContext.getRetryInterval());
-                container.sendRelaunchEvent();
-              } catch (InterruptedException e) {
-                return;
-              }
-            }
-          }.start();
-        }
+        doRelaunch(container, container.remainingRetryAttempts,
+            container.containerRetryContext.getRetryInterval());
         return ContainerState.RELAUNCHING;
+      } else if (container.canRollback()) {
+        // Rollback is possible only if the previous launch context is
+        // available.
+        container.addDiagnostics("Container Re-init Auto Rolled-Back.");
+        LOG.info("Rolling back Container reInitialization for [" +
+            container.getContainerId() + "] !!");
+        container.reInitContext =
+            container.reInitContext.createContextForRollback();
+        new KilledForReInitializationTransition().transition(container, event);
+        return ContainerState.LOCALIZED;
       } else {
         new ExitedWithFailureTransition(true).transition(container, event);
         return ContainerState.EXITED_WITH_FAILURE;
       }
     }
+
+    private void doRelaunch(final ContainerImpl container,
+        int remainingRetryAttempts, final int retryInterval) {
+      LOG.info("Relaunching Container " + container.getContainerId()
+          + ". Remaining retry attempts(after relaunch) : "
+          + remainingRetryAttempts + ". Interval between retries is "
+          + retryInterval + "ms");
+      container.wasLaunched  = false;
+      container.metrics.endRunningContainer();
+      if (retryInterval == 0) { // zero interval: send the relaunch event now
+        container.sendRelaunchEvent();
+      } else {
+        // Sleep for retryInterval ms in a new thread, then send relaunch event
+        new Thread() {
+          @Override
+          public void run() {
+            try {
+              Thread.sleep(retryInterval);
+              container.sendRelaunchEvent();
+            } catch (InterruptedException e) {
+              return; // interrupted while sleeping: no relaunch event is sent
+            }
+          }
+        }.start();
+      }
+    }
   }
 
   @Override
@@ -1179,24 +1248,29 @@ public boolean isRetryContextSet() {
 
   @Override
   public boolean shouldRetry(int errorCode) {
+    return shouldRetry(errorCode, containerRetryContext,
+        remainingRetryAttempts);
+  }
+
+  public static boolean shouldRetry(int errorCode,
+      ContainerRetryContext retryContext, int remainingRetryAttempts) {
     if (errorCode == ExitCode.SUCCESS.getExitCode()
         || errorCode == ExitCode.FORCE_KILLED.getExitCode()
         || errorCode == ExitCode.TERMINATED.getExitCode()) {
       return false;
     }
 
-    ContainerRetryPolicy retryPolicy = containerRetryContext.getRetryPolicy();
+    ContainerRetryPolicy retryPolicy = retryContext.getRetryPolicy();
     if (retryPolicy == ContainerRetryPolicy.RETRY_ON_ALL_ERRORS
         || (retryPolicy == ContainerRetryPolicy.RETRY_ON_SPECIFIC_ERROR_CODES
-            && containerRetryContext.getErrorCodes() != null
-            && containerRetryContext.getErrorCodes().contains(errorCode))) {
+        && retryContext.getErrorCodes() != null
+        && retryContext.getErrorCodes().contains(errorCode))) {
       return remainingRetryAttempts > 0
           || remainingRetryAttempts == ContainerRetryContext.RETRY_FOREVER;
     }
 
     return false;
   }
-
   /**
    * Transition to EXITED_WITH_FAILURE
    */
@@ -1231,13 +1305,12 @@ public void transition(ContainerImpl container,
       // Re configure the Retry Context
       container.containerRetryContext =
           configureRetryContext(container.context.getConf(),
-          container.launchContext, container.containerId);
+              container.launchContext, container.containerId);
       // Reset the retry attempts since its a fresh start
       container.remainingRetryAttempts =
           container.containerRetryContext.getMaxRetries();
 
-      container.resourceSet = ResourceSet.merge(
-          container.resourceSet, container.reInitContext.resourceSet);
+      container.resourceSet = container.reInitContext.mergedResourceSet();
 
       container.sendLaunchEvent();
     }
@@ -1574,4 +1647,15 @@ public void setIsReInitializing(boolean isReInitializing) {
   public boolean isReInitializing() {
     return this.isReInitializing;
   }
+
+  @Override
+  public boolean canRollback() {
+    ReInitializationContext ctx = this.reInitContext; // single volatile read
+    return (ctx != null) && ctx.canRollback();
+  }
+
+  @Override
+  public void commitUpgrade() {
+    this.reInitContext = null;
+  }
 }
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/ContainerReInitEvent.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/ContainerReInitEvent.java
index 2ccdbd7f65e..46eba03413c 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/ContainerReInitEvent.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/main/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/container/ContainerReInitEvent.java
@@ -30,18 +30,22 @@ public class ContainerReInitEvent extends ContainerEvent {
 
   private final ContainerLaunchContext reInitLaunchContext;
   private final ResourceSet resourceSet;
+  private final boolean autoCommit;
 
   /**
    * Container Re-Init Event.
-   * @param cID Container Id
-   * @param upgradeContext Upgrade context
-   * @param resourceSet Resource Set
+   * @param cID Container Id.
+   * @param upgradeContext Upgrade Context.
+   * @param resourceSet Resource Set.
+   * @param autoCommit Auto Commit.
    */
   public ContainerReInitEvent(ContainerId cID,
-      ContainerLaunchContext upgradeContext, ResourceSet resourceSet){
+      ContainerLaunchContext upgradeContext,
+      ResourceSet resourceSet, boolean autoCommit){
     super(cID, ContainerEventType.REINITIALIZE_CONTAINER);
     this.reInitLaunchContext = upgradeContext;
     this.resourceSet = resourceSet;
+    this.autoCommit = autoCommit;
   }
 
   /**
@@ -59,4 +63,12 @@ public ContainerLaunchContext getReInitLaunchContext() {
   public ResourceSet getResourceSet() {
     return resourceSet;
   }
+
+  /**
+   * Should this re-Initialization be auto-committed.
+   * @return AutoCommit.
+   */
+  public boolean isAutoCommit() {
+    return autoCommit;
+  }
 }
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/TestContainerManagerWithLCE.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/TestContainerManagerWithLCE.java
index 8a278494a43..79182cefa50 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/TestContainerManagerWithLCE.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/TestContainerManagerWithLCE.java
@@ -270,15 +270,15 @@ public void testForcefulShutdownSignal() throws IOException,
   }
 
   @Override
-  public void testContainerUpgradeSuccess() throws IOException,
+  public void testContainerUpgradeSuccessAutoCommit() throws IOException,
       InterruptedException, YarnException {
     // Don't run the test if the binary is not available.
     if (!shouldRunTest()) {
       LOG.info("LCE binary path is not passed. Not running the test");
       return;
     }
-    LOG.info("Running testContainerUpgradeSuccess");
-    super.testContainerUpgradeSuccess();
+    LOG.info("Running testContainerUpgradeSuccessAutoCommit");
+    super.testContainerUpgradeSuccessAutoCommit();
   }
 
   @Override
@@ -293,6 +293,42 @@ public void testContainerUpgradeLocalizationFailure() throws IOException,
     super.testContainerUpgradeLocalizationFailure();
   }
 
+  @Override
+  public void testContainerUpgradeSuccessExplicitCommit() throws IOException,
+      InterruptedException, YarnException {
+    // Don't run the test if the binary is not available.
+    if (!shouldRunTest()) {
+      LOG.info("LCE binary path is not passed. Not running the test");
+      return;
+    }
+    LOG.info("Running testContainerUpgradeSuccessExplicitCommit");
+    super.testContainerUpgradeSuccessExplicitCommit();
+  }
+
+  @Override
+  public void testContainerUpgradeSuccessExplicitRollback() throws IOException,
+      InterruptedException, YarnException {
+    // Don't run the test if the binary is not available.
+    if (!shouldRunTest()) {
+      LOG.info("LCE binary path is not passed. Not running the test");
+      return;
+    }
+    LOG.info("Running testContainerUpgradeSuccessExplicitRollback");
+    super.testContainerUpgradeSuccessExplicitRollback();
+  }
+
+  @Override
+  public void testContainerUpgradeRollbackDueToFailure() throws IOException,
+      InterruptedException, YarnException {
+    // Don't run the test if the binary is not available.
+    if (!shouldRunTest()) {
+      LOG.info("LCE binary path is not passed. Not running the test");
+      return;
+    }
+    LOG.info("Running testContainerUpgradeRollbackDueToFailure");
+    super.testContainerUpgradeRollbackDueToFailure();
+  }
+
   @Override
   public void testContainerUpgradeProcessFailure() throws IOException,
       InterruptedException, YarnException {
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/TestContainerManager.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/TestContainerManager.java
index 73725f6aa9f..b0674ace1c4 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/TestContainerManager.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/containermanager/TestContainerManager.java
@@ -369,9 +369,8 @@ public void testContainerLaunchAndStop() throws IOException,
       DefaultContainerExecutor.containerIsAlive(pid));
   }
 
-  @Test
-  public void testContainerUpgradeSuccess() throws IOException,
-      InterruptedException, YarnException {
+  private String[] testContainerUpgradeSuccess(boolean autoCommit)
+      throws IOException, InterruptedException, YarnException {
     containerManager.start();
     // ////// Construct the Container-id
     ContainerId cId = createContainerId(0);
@@ -381,7 +380,7 @@ public void testContainerUpgradeSuccess() throws IOException,
 
     File newStartFile = new File(tmpDir, "start_file_n.txt").getAbsoluteFile();
 
-    prepareContainerUpgrade(false, false, cId, newStartFile);
+    prepareContainerUpgrade(autoCommit, false, false, cId, newStartFile);
 
     // Assert that the First process is not alive anymore
     Assert.assertFalse("Process is still alive!",
@@ -407,6 +406,80 @@ public void testContainerUpgradeSuccess() throws IOException,
     // Assert that the New process is alive
     Assert.assertTrue("New Process is not alive!",
         DefaultContainerExecutor.containerIsAlive(newPid));
+    return new String[]{pid, newPid};
+  }
+
+  @Test
+  public void testContainerUpgradeSuccessAutoCommit() throws IOException,
+      InterruptedException, YarnException {
+    testContainerUpgradeSuccess(true);
+    // Should not be able to Commit (since already auto committed)
+    try {
+      containerManager.commitReInitialization(createContainerId(0));
+      Assert.fail();
+    } catch (Exception e) {
+      Assert.assertTrue(e.getMessage().contains("Nothing to Commit"));
+    }
+  }
+
+  @Test
+  public void testContainerUpgradeSuccessExplicitCommit() throws IOException,
+      InterruptedException, YarnException {
+    testContainerUpgradeSuccess(false);
+    ContainerId cId = createContainerId(0);
+    containerManager.commitReInitialization(cId);
+    // Should not be able to Rollback once committed
+    try {
+      containerManager.rollbackReInitialization(cId);
+      Assert.fail();
+    } catch (Exception e) {
+      Assert.assertTrue(e.getMessage().contains("Nothing to rollback to"));
+    }
+  }
+
+  @Test
+  public void testContainerUpgradeSuccessExplicitRollback() throws IOException,
+      InterruptedException, YarnException {
+    String[] pids = testContainerUpgradeSuccess(false);
+
+    // Delete the old start file; the rolled-back process should re-create it.
+    File oldStartFile = new File(tmpDir, "start_file_o.txt").getAbsoluteFile();
+    oldStartFile.delete();
+
+    ContainerId cId = createContainerId(0);
+    // Explicit Rollback
+    containerManager.rollbackReInitialization(cId);
+
+    // Original should be dead anyway
+    Assert.assertFalse("Original Process is still alive!",
+        DefaultContainerExecutor.containerIsAlive(pids[0]));
+
+    // Wait (up to ~20s) for as long as the upgraded process is still alive.
+    int timeoutSecs = 0;
+    while (DefaultContainerExecutor.containerIsAlive(pids[1])
+        && timeoutSecs++ < 20) {
+      Thread.sleep(1000);
+      LOG.info("Waiting for Upgraded process to die..");
+    }
+
+    timeoutSecs = 0;
+    // Wait (up to ~20s) for the old start file to be created again.
+    while (!oldStartFile.exists() && timeoutSecs++ < 20) {
+      Thread.sleep(1000);
+      LOG.info("Waiting for New process start-file to be created");
+    }
+
+    // Now verify the contents of the file
+    BufferedReader reader =
+        new BufferedReader(new FileReader(oldStartFile));
+    Assert.assertEquals("Hello World!", reader.readLine());
+    // Get the pid of the process
+    String rolledBackPid = reader.readLine().trim();
+    // No more lines
+    Assert.assertEquals(null, reader.readLine());
+
+    Assert.assertNotEquals("The Rolled-back process should be a different pid",
+        pids[0], rolledBackPid);
+  }
 
   @Test
@@ -424,7 +497,7 @@ public void testContainerUpgradeLocalizationFailure() throws IOException,
 
     File newStartFile = new File(tmpDir, "start_file_n.txt").getAbsoluteFile();
 
-    prepareContainerUpgrade(true, true, cId, newStartFile);
+    prepareContainerUpgrade(false, true, true, cId, newStartFile);
 
     // Assert that the First process is STILL alive
     // since upgrade was terminated..
@@ -447,22 +520,69 @@ public void testContainerUpgradeProcessFailure() throws IOException,
 
     File newStartFile = new File(tmpDir, "start_file_n.txt").getAbsoluteFile();
 
-    prepareContainerUpgrade(true, false, cId, newStartFile);
+    // Since Autocommit is true, there is also no rollback context...
+    // which implies that if the new process fails, since there is no
+    // rollback, it is terminated.
+    prepareContainerUpgrade(true, true, false, cId, newStartFile);
 
     // Assert that the First process is not alive anymore
     Assert.assertFalse("Process is still alive!",
         DefaultContainerExecutor.containerIsAlive(pid));
   }
 
+  @Test
+  public void testContainerUpgradeRollbackDueToFailure() throws IOException,
+      InterruptedException, YarnException {
+    if (Shell.WINDOWS) {
+      return;
+    }
+    containerManager.start();
+    // Construct the Container-id
+    ContainerId cId = createContainerId(0);
+    File oldStartFile = new File(tmpDir, "start_file_o.txt").getAbsoluteFile();
+
+    String pid = prepareInitialContainer(cId, oldStartFile);
+
+    File newStartFile = new File(tmpDir, "start_file_n.txt").getAbsoluteFile();
+    // autoCommit=false with a failing start script, so a rollback is expected
+    prepareContainerUpgrade(false, true, false, cId, newStartFile);
+
+    // Assert that the First process is not alive anymore
+    Assert.assertFalse("Original Process is still alive!",
+        DefaultContainerExecutor.containerIsAlive(pid));
+
+    int timeoutSecs = 0;
+    // Wait for oldStartFile to be created
+    while (!oldStartFile.exists() && timeoutSecs++ < 20) {
+      LOG.info("Files: " +
+          Arrays.toString(oldStartFile.getParentFile().list()));
+      Thread.sleep(1000);
+      LOG.info("Waiting for Rolled-back process start-file to be created");
+    }
+    // Now verify the contents of the file
+    try (BufferedReader reader =
+        new BufferedReader(new FileReader(oldStartFile))) {
+      Assert.assertEquals("Hello World!", reader.readLine());
+      // Get the pid of the process
+      String rolledBackPid = reader.readLine().trim();
+      // No more lines
+      Assert.assertNull(reader.readLine());
+      Assert.assertNotEquals(
+          "The Rolled-back process should be a different pid",
+          pid, rolledBackPid);
+    }
+  }
+
   /**
    * Prepare a launch Context for container upgrade and request the
    * Container Manager to re-initialize a running container using the
    * new launch context.
+   * @param autoCommit Enable autoCommit.
    * @param failCmd injects a start script that intentionally fails.
    * @param failLoc injects a bad file Location that will fail localization.
    */
-  private void prepareContainerUpgrade(boolean failCmd, boolean failLoc,
-      ContainerId cId, File startFile)
+  private void prepareContainerUpgrade(boolean autoCommit, boolean failCmd,
+      boolean failLoc, ContainerId cId, File startFile)
       throws FileNotFoundException, YarnException, InterruptedException {
     // Re-write scriptfile and processStartFile
     File scriptFile = Shell.appendScriptExtension(tmpDir, "scriptFile_new");
@@ -471,13 +591,15 @@ private void prepareContainerUpgrade(boolean failCmd, boolean failLoc,
     writeScriptFile(fileWriter, "Upgrade World!", startFile, cId, failCmd);
 
     ContainerLaunchContext containerLaunchContext =
-        prepareContainerLaunchContext(scriptFile, "dest_file_new", failLoc);
+        prepareContainerLaunchContext(scriptFile, "dest_file_new", failLoc, 0);
 
-    containerManager.upgradeContainer(cId, containerLaunchContext);
+    containerManager.reInitializeContainer(cId, containerLaunchContext,
+        autoCommit);
     try {
-      containerManager.upgradeContainer(cId, containerLaunchContext);
+      containerManager.reInitializeContainer(cId, containerLaunchContext,
+          autoCommit);
     } catch (Exception e) {
-      Assert.assertTrue(e.getMessage().contains("Cannot perform UPGRADE"));
+      Assert.assertTrue(e.getMessage().contains("Cannot perform RE_INIT"));
     }
     int timeoutSecs = 0;
     int maxTimeToWait = failLoc ? 10 : 20;
@@ -501,7 +623,7 @@ private String prepareInitialContainer(ContainerId cId, File startFile)
     writeScriptFile(fileWriterOld, "Hello World!", startFile, cId, false);
 
     ContainerLaunchContext containerLaunchContext =
-        prepareContainerLaunchContext(scriptFileOld, "dest_file", false);
+        prepareContainerLaunchContext(scriptFileOld, "dest_file", false, 4);
 
     StartContainerRequest scRequest =
         StartContainerRequest.newInstance(containerLaunchContext,
@@ -562,7 +684,7 @@ private void writeScriptFile(PrintWriter fileWriter, String startLine,
   }
 
   private ContainerLaunchContext prepareContainerLaunchContext(File scriptFile,
-      String destFName, boolean putBadFile) {
+      String destFName, boolean putBadFile, int numRetries) {
     ContainerLaunchContext containerLaunchContext =
         recordFactory.newRecordInstance(ContainerLaunchContext.class);
     URL resourceAlpha = null;
@@ -592,7 +714,7 @@ private ContainerLaunchContext prepareContainerLaunchContext(File scriptFile,
     ContainerRetryContext containerRetryContext = ContainerRetryContext
         .newInstance(
             ContainerRetryPolicy.RETRY_ON_SPECIFIC_ERROR_CODES,
-            new HashSet<>(Arrays.asList(Integer.valueOf(111))), 4, 0);
+            new HashSet<>(Arrays.asList(Integer.valueOf(111))), numRetries, 0);
     containerLaunchContext.setContainerRetryContext(containerRetryContext);
     List<String> commands = Arrays.asList(
         Shell.getRunScriptCommand(scriptFile));
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/webapp/MockContainer.java b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/webapp/MockContainer.java
index d2b8d63f116..5f1aab9199f 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/webapp/MockContainer.java
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-server/hadoop-yarn-server-nodemanager/src/test/java/org/apache/hadoop/yarn/server/nodemanager/webapp/MockContainer.java
@@ -199,4 +199,14 @@ public void setIsReInitializing(boolean isReInitializing) {
   public boolean isReInitializing() {
     return false;
   }
+
+  @Override
+  public boolean canRollback() {
+    return false; // this mock never reports a rollback as available
+  }
+
+  @Override
+  public void commitUpgrade() {
+    // Intentionally empty: this mock performs no work on commit.
+  }
 }