From cdd6efd3ab6917e30b4c5c7b261f61838901bb37 Mon Sep 17 00:00:00 2001 From: Ahmed Hussein Date: Wed, 8 Jan 2020 11:08:13 -0600 Subject: [PATCH] MAPREDUCE-7252. Handling 0 progress in SimpleExponential task runtime estimator Signed-off-by: Jonathan Eagles --- .../v2/app/speculate/DataStatistics.java | 28 ++-- ...SimpleExponentialTaskRuntimeEstimator.java | 67 ++++++--- .../forecast/SimpleExponentialSmoothing.java | 131 +++++++++++------- .../app/speculate/forecast/package-info.java | 20 +++ .../apache/hadoop/mapreduce/v2/app/MRApp.java | 42 +++++- .../v2/TestSpeculativeExecutionWithMRApp.java | 116 ++++++++++++++-- 6 files changed, 308 insertions(+), 96 deletions(-) create mode 100644 hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/package-info.java diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/DataStatistics.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/DataStatistics.java index 9f1c12243f7..036eb457142 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/DataStatistics.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/DataStatistics.java @@ -18,6 +18,11 @@ package org.apache.hadoop.mapreduce.v2.app.speculate; public class DataStatistics { + + /** + * factor used to calculate confidence interval within 95%. + */ + private static final double DEFAULT_CI_FACTOR = 1.96; private int count = 0; private double sum = 0; private double sumSquares = 0; @@ -25,25 +30,26 @@ public class DataStatistics { public DataStatistics() { } - public DataStatistics(double initNum) { + public DataStatistics(final double initNum) { this.count = 1; this.sum = initNum; this.sumSquares = initNum * initNum; } - public synchronized void add(double newNum) { + public synchronized void add(final double newNum) { this.count++; this.sum += newNum; this.sumSquares += newNum * newNum; } - public synchronized void updateStatistics(double old, double update) { - this.sum += update - old; - this.sumSquares += (update * update) - (old * old); + public synchronized void updateStatistics(final double old, + final double update) { + this.sum += update - old; + this.sumSquares += (update * update) - (old * old); } public synchronized double mean() { - return count == 0 ? 0.0 : sum/count; + return count == 0 ? 
0.0 : sum / count; } public synchronized double var() { @@ -52,14 +58,14 @@ public class DataStatistics { return 0.0; } double mean = mean(); - return Math.max((sumSquares/count) - mean * mean, 0.0d); + return Math.max((sumSquares / count) - mean * mean, 0.0d); } public synchronized double std() { return Math.sqrt(this.var()); } - public synchronized double outlier(float sigma) { + public synchronized double outlier(final float sigma) { if (count != 0.0) { return mean() + std() * sigma; } @@ -78,10 +84,12 @@ public class DataStatistics { * @return the mean value adding 95% confidence interval */ public synchronized double meanCI() { - if (count <= 1) return 0.0; + if (count <= 1) { + return 0.0; + } double currMean = mean(); double currStd = std(); - return currMean + (1.96 * currStd / Math.sqrt(count)); + return currMean + (DEFAULT_CI_FACTOR * currStd / Math.sqrt(count)); } public String toString() { diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/SimpleExponentialTaskRuntimeEstimator.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/SimpleExponentialTaskRuntimeEstimator.java index f244b20e3e0..28389169bb5 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/SimpleExponentialTaskRuntimeEstimator.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/SimpleExponentialTaskRuntimeEstimator.java @@ -33,7 +33,22 @@ import org.apache.hadoop.mapreduce.v2.app.speculate.forecast.SimpleExponentialSm * A task Runtime Estimator based on exponential smoothing. */ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase { - private final static long DEFAULT_ESTIMATE_RUNTIME = -1L; + + /** + * The default value returned by the estimator when no records exist. + */ + private static final long DEFAULT_ESTIMATE_RUNTIME = -1L; + + /** + * The default value used in place of a 0.0 forecast + * to avoid division by zero. + */ + private static final double DEFAULT_PROGRESS_VALUE = 1E-10; + + /** + * Factor used to calculate the confidence interval. + */ + private static final double CONFIDENCE_INTERVAL_FACTOR = 0.25; /** * Constant time used to calculate the smoothing exponential factor. @@ -53,11 +68,15 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase { */ private long stagnatedWindow; + /** + * A map from each task attempt ID to its simple exponential smoothing model.
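+ * Entries are created lazily in incorporateReading, once the attempt's start time is known.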
+ */ private final ConcurrentMap<TaskAttemptId, AtomicReference<SimpleExponentialSmoothing>> estimates = new ConcurrentHashMap<>(); - private SimpleExponentialSmoothing getForecastEntry(TaskAttemptId attemptID) { + private SimpleExponentialSmoothing getForecastEntry( + final TaskAttemptId attemptID) { AtomicReference<SimpleExponentialSmoothing> entryRef = estimates .get(attemptID); if (entryRef == null) { @@ -66,13 +85,13 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase { return entryRef.get(); } - private void incorporateReading(TaskAttemptId attemptID, - float newRawData, long newTimeStamp) { + private void incorporateReading(final TaskAttemptId attemptID, + final float newRawData, final long newTimeStamp) { SimpleExponentialSmoothing foreCastEntry = getForecastEntry(attemptID); if (foreCastEntry == null) { Long tStartTime = startTimes.get(attemptID); // skip if the startTime is not set yet - if(tStartTime == null) { + if (tStartTime == null) { return; } estimates.putIfAbsent(attemptID, @@ -86,7 +105,8 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase { } @Override - public void contextualize(Configuration conf, AppContext context) { + public void contextualize(final Configuration conf, + final AppContext context) { super.contextualize(conf, context); constTime @@ -103,18 +123,16 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase { } @Override - public long estimatedRuntime(TaskAttemptId id) { + public long estimatedRuntime(final TaskAttemptId id) { SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id); if (foreCastEntry == null) { return DEFAULT_ESTIMATE_RUNTIME; } - // TODO: What should we do when estimate is zero - double remainingWork = Math.min(1.0, 1.0 - foreCastEntry.getRawData()); - double forecast = foreCastEntry.getForecast(); - if (forecast <= 0.0) { - return DEFAULT_ESTIMATE_RUNTIME; - } - long remainingTime = (long)(remainingWork / forecast); + double remainingWork = Math + .max(0.0, Math.min(1.0, 1.0 - foreCastEntry.getRawData())); + double forecast = Math + .max(DEFAULT_PROGRESS_VALUE, foreCastEntry.getForecast()); + long remainingTime = (long) (remainingWork / forecast); long estimatedRuntime = remainingTime + foreCastEntry.getTimeStamp() - foreCastEntry.getStartTime(); @@ -122,30 +140,32 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase { } @Override - public long estimatedNewAttemptRuntime(TaskId id) { + public long estimatedNewAttemptRuntime(final TaskId id) { DataStatistics statistics = dataStatisticsForTask(id); if (statistics == null) { - return -1L; + return DEFAULT_ESTIMATE_RUNTIME; } double statsMeanCI = statistics.meanCI(); double expectedVal = - statsMeanCI + Math.min(statsMeanCI * 0.25, statistics.std() / 2); - return (long)(expectedVal); + statsMeanCI + Math.min(statsMeanCI * CONFIDENCE_INTERVAL_FACTOR, + statistics.std() / 2); + return (long) (expectedVal); } @Override - public boolean hasStagnatedProgress(TaskAttemptId id, long timeStamp) { + public boolean hasStagnatedProgress(final TaskAttemptId id, + final long timeStamp) { SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id); - if(foreCastEntry == null) { + if (foreCastEntry == null) { return false; } return foreCastEntry.isDataStagnated(timeStamp); } @Override - public long runtimeEstimateVariance(TaskAttemptId id) { + public long runtimeEstimateVariance(final TaskAttemptId id) { SimpleExponentialSmoothing forecastEntry = getForecastEntry(id); if (forecastEntry == null) { return DEFAULT_ESTIMATE_RUNTIME; } @@ -154,12 +174,13 @@ public class
SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase { if (forecastEntry.isDefaultForecast(forecast)) { return DEFAULT_ESTIMATE_RUNTIME; } - //TODO: What is the best way to measure variance in runtime + //TODO What is the best way to measure variance in runtime return 0L; } @Override - public void updateAttempt(TaskAttemptStatus status, long timestamp) { + public void updateAttempt(final TaskAttemptStatus status, + final long timestamp) { super.updateAttempt(status, timestamp); TaskAttemptId attemptID = status.id; diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/SimpleExponentialSmoothing.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/SimpleExponentialSmoothing.java index e1ef7bec907..0e00068296a 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/SimpleExponentialSmoothing.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/SimpleExponentialSmoothing.java @@ -24,108 +24,145 @@ import java.util.concurrent.atomic.AtomicReference; * Implementation of the static model for Simple exponential smoothing. */ public class SimpleExponentialSmoothing { - public final static double DEFAULT_FORECAST = -1.0; + private static final double DEFAULT_FORECAST = -1.0; private final int kMinimumReads; private final long kStagnatedWindow; private final long startTime; private long timeConstant; + /** + * Holds reference to the current forecast record. 
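+ * The record is swapped atomically via compareAndSet in incorporateReading, so concurrent readings retry until they are applied.
+ * Each append smooths the observed progress rate: with elapsed time dt and time constant tau, the new forecast is approximately (1 - exp(-dt / tau)) * sample + exp(-dt / tau) * previousForecast.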
+ */ private AtomicReference forecastRefEntry; - public static SimpleExponentialSmoothing createForecast(long timeConstant, - int skipCnt, long stagnatedWindow, long timeStamp) { + public static SimpleExponentialSmoothing createForecast( + final long timeConstant, + final int skipCnt, final long stagnatedWindow, final long timeStamp) { return new SimpleExponentialSmoothing(timeConstant, skipCnt, stagnatedWindow, timeStamp); } - SimpleExponentialSmoothing(long ktConstant, int skipCnt, - long stagnatedWindow, long timeStamp) { - kMinimumReads = skipCnt; - kStagnatedWindow = stagnatedWindow; + SimpleExponentialSmoothing(final long ktConstant, final int skipCnt, + final long stagnatedWindow, final long timeStamp) { + this.kMinimumReads = skipCnt; + this.kStagnatedWindow = stagnatedWindow; this.timeConstant = ktConstant; this.startTime = timeStamp; this.forecastRefEntry = new AtomicReference(null); } private class ForecastRecord { - private double alpha; - private long timeStamp; - private double sample; - private double rawData; + private final double alpha; + private final long timeStamp; + private final double sample; + private final double rawData; private double forecast; - private double sseError; - private long myIndex; + private final double sseError; + private final long myIndex; + private ForecastRecord prevRec; - ForecastRecord(double forecast, double rawData, long timeStamp) { - this(0.0, forecast, rawData, forecast, timeStamp, 0.0, 0); + ForecastRecord(final double currForecast, final double currRawData, + final long currTimeStamp) { + this(0.0, currForecast, currRawData, currForecast, currTimeStamp, 0.0, 0); } - ForecastRecord(double alpha, double sample, double rawData, - double forecast, long timeStamp, double accError, long index) { - this.timeStamp = timeStamp; - this.alpha = alpha; - this.sseError = 0.0; - this.sample = sample; - this.forecast = forecast; - this.rawData = rawData; + ForecastRecord(final double alphaVal, final double currSample, + final double currRawData, + final double currForecast, final long currTimeStamp, + final double accError, + final long index) { + this.timeStamp = currTimeStamp; + this.alpha = alphaVal; + this.sample = currSample; + this.forecast = currForecast; + this.rawData = currRawData; this.sseError = accError; this.myIndex = index; } - private double preProcessRawData(double rData, long newTime) { + private ForecastRecord createForecastRecord(final double alphaVal, + final double currSample, + final double currRawData, + final double currForecast, final long currTimeStamp, + final double accError, + final long index, + final ForecastRecord prev) { + ForecastRecord forecastRec = + new ForecastRecord(alphaVal, currSample, currRawData, currForecast, + currTimeStamp, accError, index); + forecastRec.prevRec = prev; + return forecastRec; + } + + private double preProcessRawData(final double rData, final long newTime) { return processRawData(this.rawData, this.timeStamp, rData, newTime); } - public ForecastRecord append(long newTimeStamp, double rData) { - if (this.timeStamp > newTimeStamp) { + public ForecastRecord append(final long newTimeStamp, final double rData) { + if (this.timeStamp >= newTimeStamp + && Double.compare(this.rawData, rData) >= 0) { + // progress reported twice. Do nothing. 
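+ // A reading that advances neither the timestamp nor the progress would feed a zero or negative time delta into processRawData, so it is ignored.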
return this; } - double newSample = preProcessRawData(rData, newTimeStamp); + ForecastRecord refRecord = this; + if (newTimeStamp == this.timeStamp) { + // we need to restore old value if possible + if (this.prevRec != null) { + refRecord = this.prevRec; + } + } + double newSample = refRecord.preProcessRawData(rData, newTimeStamp); long deltaTime = this.timeStamp - newTimeStamp; - if (this.myIndex == kMinimumReads) { + if (refRecord.myIndex == kMinimumReads) { timeConstant = Math.max(timeConstant, newTimeStamp - startTime); } double smoothFactor = 1 - Math.exp(((double) deltaTime) / timeConstant); double forecastVal = - smoothFactor * newSample + (1.0 - smoothFactor) * this.forecast; + smoothFactor * newSample + (1.0 - smoothFactor) * refRecord.forecast; double newSSEError = - this.sseError + Math.pow(newSample - this.forecast, 2); - return new ForecastRecord(smoothFactor, newSample, rData, forecastVal, - newTimeStamp, newSSEError, this.myIndex + 1); + refRecord.sseError + Math.pow(newSample - refRecord.forecast, 2); + return refRecord + .createForecastRecord(smoothFactor, newSample, rData, forecastVal, + newTimeStamp, newSSEError, refRecord.myIndex + 1, refRecord); } - } - public boolean isDataStagnated(long timeStamp) { + /** + * checks if the task is hanging up. + * @param timeStamp current time of the scan. + * @return true if we have number of samples > kMinimumReads and the record + * timestamp has expired. + */ + public boolean isDataStagnated(final long timeStamp) { ForecastRecord rec = forecastRefEntry.get(); - if (rec != null && rec.myIndex <= kMinimumReads) { - return (rec.timeStamp + kStagnatedWindow) < timeStamp; + if (rec != null && rec.myIndex > kMinimumReads) { + return (rec.timeStamp + kStagnatedWindow) > timeStamp; } return false; } - static double processRawData(double oldRawData, long oldTime, - double newRawData, long newTime) { + static double processRawData(final double oldRawData, final long oldTime, + final double newRawData, final long newTime) { double rate = (newRawData - oldRawData) / (newTime - oldTime); return rate; } - public void incorporateReading(long timeStamp, double rawData) { + public void incorporateReading(final long timeStamp, + final double currRawData) { ForecastRecord oldRec = forecastRefEntry.get(); if (oldRec == null) { double oldForecast = - processRawData(0, startTime, rawData, timeStamp); + processRawData(0, startTime, currRawData, timeStamp); forecastRefEntry.compareAndSet(null, new ForecastRecord(oldForecast, 0.0, startTime)); - incorporateReading(timeStamp, rawData); + incorporateReading(timeStamp, currRawData); return; } while (!forecastRefEntry.compareAndSet(oldRec, oldRec.append(timeStamp, - rawData))) { + currRawData))) { oldRec = forecastRefEntry.get(); } - } public double getForecast() { @@ -136,7 +173,7 @@ public class SimpleExponentialSmoothing { return DEFAULT_FORECAST; } - public boolean isDefaultForecast(double value) { + public boolean isDefaultForecast(final double value) { return value == DEFAULT_FORECAST; } @@ -148,7 +185,7 @@ public class SimpleExponentialSmoothing { return DEFAULT_FORECAST; } - public boolean isErrorWithinBound(double bound) { + public boolean isErrorWithinBound(final double bound) { double squaredErr = getSSE(); if (squaredErr < 0) { return false; @@ -185,8 +222,8 @@ public class SimpleExponentialSmoothing { String res = "NULL"; ForecastRecord rec = forecastRefEntry.get(); if (rec != null) { - res = "rec.index = " + rec.myIndex + ", forecast t: " + rec.timeStamp + - ", forecast: " + rec.forecast + res = 
"rec.index = " + rec.myIndex + ", forecast t: " + rec.timeStamp + + ", forecast: " + rec.forecast + ", sample: " + rec.sample + ", raw: " + rec.rawData + ", error: " + rec.sseError + ", alpha: " + rec.alpha; } diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/package-info.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/package-info.java new file mode 100644 index 00000000000..52b8955fb2a --- /dev/null +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/main/java/org/apache/hadoop/mapreduce/v2/app/speculate/forecast/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +@InterfaceAudience.Private +package org.apache.hadoop.mapreduce.v2.app.speculate.forecast; +import org.apache.hadoop.classification.InterfaceAudience; diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/test/java/org/apache/hadoop/mapreduce/v2/app/MRApp.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/test/java/org/apache/hadoop/mapreduce/v2/app/MRApp.java index a6e57ca8046..70ea18a13b3 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/test/java/org/apache/hadoop/mapreduce/v2/app/MRApp.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-app/src/test/java/org/apache/hadoop/mapreduce/v2/app/MRApp.java @@ -22,8 +22,11 @@ import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.net.InetSocketAddress; +import java.util.Arrays; import java.util.EnumSet; +import java.util.List; +import java.util.stream.Collectors; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileContext; import org.apache.hadoop.fs.FileSystem; @@ -372,18 +375,45 @@ public class MRApp extends MRAppMaster { TaskAttemptReport report = attempt.getReport(); while (!finalState.equals(report.getTaskAttemptState()) && timeoutSecs++ < 20) { - System.out.println("TaskAttempt State is : " + report.getTaskAttemptState() + - " Waiting for state : " + finalState + - " progress : " + report.getProgress()); + System.out.println( + "TaskAttempt " + attempt.getID().toString() + " State is : " + + report.getTaskAttemptState() + + " Waiting for state : " + finalState + + " progress : " + report.getProgress()); report = attempt.getReport(); Thread.sleep(500); } - System.out.println("TaskAttempt State is : " + report.getTaskAttemptState()); + System.out.println("TaskAttempt State is : " + + report.getTaskAttemptState()); 
Assert.assertEquals("TaskAttempt state is not correct (timedout)", - finalState, + finalState, report.getTaskAttemptState()); } + public void waitForState(TaskAttempt attempt, + TaskAttemptState...finalStates) throws Exception { + int timeoutSecs = 0; + TaskAttemptReport report = attempt.getReport(); + List<TaskAttemptState> targetStates = Arrays.asList(finalStates); + String statesValues = targetStates.stream().map(Object::toString).collect( + Collectors.joining(",")); + while (!targetStates.contains(report.getTaskAttemptState()) && + timeoutSecs++ < 20) { + System.out.println( + "TaskAttempt " + attempt.getID().toString() + " State is : " + + report.getTaskAttemptState() + + " Waiting for states: " + statesValues + + ". current state is : " + report.getTaskAttemptState() + + ". progress : " + report.getProgress()); + report = attempt.getReport(); + Thread.sleep(500); + } + System.out.println("TaskAttempt State is : " + + report.getTaskAttemptState()); + Assert.assertTrue("TaskAttempt state is not correct (timedout)", + targetStates.contains(report.getTaskAttemptState())); + } + public void waitForState(Task task, TaskState finalState) throws Exception { int timeoutSecs = 0; TaskReport report = task.getReport(); @@ -396,7 +426,7 @@ public class MRApp extends MRAppMaster { Thread.sleep(500); } System.out.println("Task State is : " + report.getTaskState()); - Assert.assertEquals("Task state is not correct (timedout)", finalState, + Assert.assertEquals("Task state is not correct (timedout)", finalState, report.getTaskState()); } diff --git a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-jobclient/src/test/java/org/apache/hadoop/mapreduce/v2/TestSpeculativeExecutionWithMRApp.java b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-jobclient/src/test/java/org/apache/hadoop/mapreduce/v2/TestSpeculativeExecutionWithMRApp.java index 940f142fdf7..d4d432b94d8 100644 --- a/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-jobclient/src/test/java/org/apache/hadoop/mapreduce/v2/TestSpeculativeExecutionWithMRApp.java +++ b/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-jobclient/src/test/java/org/apache/hadoop/mapreduce/v2/TestSpeculativeExecutionWithMRApp.java @@ -18,11 +18,14 @@ package org.apache.hadoop.mapreduce.v2; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; import java.util.Arrays; import java.util.Collection; import java.util.Iterator; import java.util.Map; import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import org.apache.hadoop.mapreduce.MRJobConfig; @@ -50,19 +53,94 @@ import org.apache.hadoop.yarn.event.EventHandler; import org.apache.hadoop.yarn.util.Clock; import org.apache.hadoop.yarn.util.ControlledClock; import org.apache.hadoop.yarn.util.SystemClock; +import org.junit.Rule; import org.junit.Test; import com.google.common.base.Supplier; +import org.junit.rules.TestRule; +import org.junit.runner.Description; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import org.junit.runners.model.Statement; +/** + * Tests speculative execution using MRApp. + * It tests the speculation behavior given a list of estimator classes. + */ @SuppressWarnings({ "unchecked", "rawtypes" }) @RunWith(Parameterized.class) public class TestSpeculativeExecutionWithMRApp { - + /** Number of times to re-try the failing tests.
*/ + private static final int ASSERT_SPECULATIONS_COUNT_RETRIES = 3; private static final int NUM_MAPPERS = 5; private static final int NUM_REDUCERS = 0; + /** + * Speculation has non-deterministic behavior due to racing and timing. Use + * retry to verify that junit tests can pass. + */ + @Retention(RetentionPolicy.RUNTIME) + public @interface Retry {} + + /** + * The type Retry rule. + */ + class RetryRule implements TestRule { + + private AtomicInteger retryCount; + + /** + * Instantiates a new Retry rule. + * + * @param retries the retries + */ + RetryRule(int retries) { + super(); + this.retryCount = new AtomicInteger(retries); + } + + @Override + public Statement apply(final Statement base, + final Description description) { + return new Statement() { + @Override + public void evaluate() throws Throwable { + Throwable caughtThrowable = null; + + while (retryCount.getAndDecrement() > 0) { + try { + base.evaluate(); + return; + } catch (Throwable t) { + if (retryCount.get() > 0 && + description.getAnnotation(Retry.class) != null) { + caughtThrowable = t; + System.out.println( + description.getDisplayName() + + ": Failed, " + + retryCount.toString() + + " retries remain"); + } else { + throw caughtThrowable; + } + } + } + } + }; + } + } + + /** + * The Rule. + */ + @Rule + public RetryRule rule = new RetryRule(ASSERT_SPECULATIONS_COUNT_RETRIES); + + /** + * Gets test parameters. + * + * @return the test parameters + */ @Parameterized.Parameters(name = "{index}: TaskEstimator(EstimatorClass {0})") public static Collection getTestParameters() { return Arrays.asList(new Object[][] { @@ -73,12 +151,23 @@ public class TestSpeculativeExecutionWithMRApp { private Class estimatorClass; + /** + * Instantiates a new Test speculative execution with mr app. + * + * @param estimatorKlass the estimator klass + */ public TestSpeculativeExecutionWithMRApp( Class estimatorKlass) { this.estimatorClass = estimatorKlass; } - @Test + /** + * Test speculate successful without update events. + * + * @throws Exception the exception + */ + @Retry + @Test (timeout = 360000) public void testSpeculateSuccessfulWithoutUpdateEvents() throws Exception { Clock actualClock = SystemClock.getInstance(); @@ -128,7 +217,8 @@ public class TestSpeculativeExecutionWithMRApp { TaskAttemptEventType.TA_DONE)); appEventHandler.handle(new TaskAttemptEvent(taskAttempt.getKey(), TaskAttemptEventType.TA_CONTAINER_COMPLETED)); - app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED); + app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED, + TaskAttemptState.KILLED); } } } @@ -150,8 +240,14 @@ public class TestSpeculativeExecutionWithMRApp { app.waitForState(Service.STATE.STOPPED); } - @Test - public void testSepculateSuccessfulWithUpdateEvents() throws Exception { + /** + * Test speculate successful with update events. 
+ * + * @throws Exception the exception + */ + @Retry + @Test (timeout = 360000) + public void testSpeculateSuccessfulWithUpdateEvents() throws Exception { Clock actualClock = SystemClock.getInstance(); final ControlledClock clock = new ControlledClock(actualClock); @@ -198,7 +294,8 @@ public class TestSpeculativeExecutionWithMRApp { appEventHandler.handle(new TaskAttemptEvent(taskAttempt.getKey(), TaskAttemptEventType.TA_CONTAINER_COMPLETED)); numTasksToFinish--; - app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED); + app.waitForState(taskAttempt.getValue(), TaskAttemptState.KILLED, + TaskAttemptState.SUCCEEDED); } else { // The last task is chosen for speculation TaskAttemptStatus status = @@ -214,13 +311,12 @@ public class TestSpeculativeExecutionWithMRApp { } clock.setTime(System.currentTimeMillis() + 15000); - // give a chance to the speculator thread to run a scan before we proceed - // with updating events - Thread.yield(); + for (Map.Entry task : tasks.entrySet()) { for (Map.Entry taskAttempt : task.getValue() .getAttempts().entrySet()) { - if (taskAttempt.getValue().getState() != TaskAttemptState.SUCCEEDED) { + if (!(taskAttempt.getValue().getState() == TaskAttemptState.SUCCEEDED + || taskAttempt.getValue().getState() == TaskAttemptState.KILLED)) { TaskAttemptStatus status = createTaskAttemptStatus(taskAttempt.getKey(), (float) 0.75, TaskAttemptState.RUNNING);