MAPREDUCE-7252. Handling 0 progress in SimpleExponential task runtime estimator

Signed-off-by: Jonathan Eagles <jeagles@gmail.com>
This commit is contained in:
Ahmed Hussein 2020-01-08 11:08:13 -06:00 committed by Jonathan Eagles
parent 52cc20e9ea
commit cdd6efd3ab
6 changed files with 308 additions and 96 deletions

View File

@ -18,6 +18,11 @@
package org.apache.hadoop.mapreduce.v2.app.speculate; package org.apache.hadoop.mapreduce.v2.app.speculate;
public class DataStatistics { public class DataStatistics {
/**
* factor used to calculate confidence interval within 95%.
*/
private static final double DEFAULT_CI_FACTOR = 1.96;
private int count = 0; private int count = 0;
private double sum = 0; private double sum = 0;
private double sumSquares = 0; private double sumSquares = 0;
@ -25,25 +30,26 @@ public class DataStatistics {
public DataStatistics() { public DataStatistics() {
} }
public DataStatistics(double initNum) { public DataStatistics(final double initNum) {
this.count = 1; this.count = 1;
this.sum = initNum; this.sum = initNum;
this.sumSquares = initNum * initNum; this.sumSquares = initNum * initNum;
} }
public synchronized void add(double newNum) { public synchronized void add(final double newNum) {
this.count++; this.count++;
this.sum += newNum; this.sum += newNum;
this.sumSquares += newNum * newNum; this.sumSquares += newNum * newNum;
} }
public synchronized void updateStatistics(double old, double update) { public synchronized void updateStatistics(final double old,
this.sum += update - old; final double update) {
this.sumSquares += (update * update) - (old * old); this.sum += update - old;
this.sumSquares += (update * update) - (old * old);
} }
public synchronized double mean() { public synchronized double mean() {
return count == 0 ? 0.0 : sum/count; return count == 0 ? 0.0 : sum / count;
} }
public synchronized double var() { public synchronized double var() {
@ -52,14 +58,14 @@ public class DataStatistics {
return 0.0; return 0.0;
} }
double mean = mean(); double mean = mean();
return Math.max((sumSquares/count) - mean * mean, 0.0d); return Math.max((sumSquares / count) - mean * mean, 0.0d);
} }
public synchronized double std() { public synchronized double std() {
return Math.sqrt(this.var()); return Math.sqrt(this.var());
} }
public synchronized double outlier(float sigma) { public synchronized double outlier(final float sigma) {
if (count != 0.0) { if (count != 0.0) {
return mean() + std() * sigma; return mean() + std() * sigma;
} }
@ -78,10 +84,12 @@ public class DataStatistics {
* @return the mean value adding 95% confidence interval * @return the mean value adding 95% confidence interval
*/ */
public synchronized double meanCI() { public synchronized double meanCI() {
if (count <= 1) return 0.0; if (count <= 1) {
return 0.0;
}
double currMean = mean(); double currMean = mean();
double currStd = std(); double currStd = std();
return currMean + (1.96 * currStd / Math.sqrt(count)); return currMean + (DEFAULT_CI_FACTOR * currStd / Math.sqrt(count));
} }
public String toString() { public String toString() {

View File

@ -33,7 +33,22 @@ import org.apache.hadoop.mapreduce.v2.app.speculate.forecast.SimpleExponentialSm
* A task Runtime Estimator based on exponential smoothing. * A task Runtime Estimator based on exponential smoothing.
*/ */
public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase { public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
private final static long DEFAULT_ESTIMATE_RUNTIME = -1L;
/**
* The default value returned by the estimator when no records exist.
*/
private static final long DEFAULT_ESTIMATE_RUNTIME = -1L;
/**
* Given a forecast of value 0.0, it is getting replaced by the default value
* to avoid division by 0.
*/
private static final double DEFAULT_PROGRESS_VALUE = 1E-10;
/**
* Factor used to calculate the confidence interval.
*/
private static final double CONFIDENCE_INTERVAL_FACTOR = 0.25;
/** /**
* Constant time used to calculate the smoothing exponential factor. * Constant time used to calculate the smoothing exponential factor.
@ -53,11 +68,15 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
*/ */
private long stagnatedWindow; private long stagnatedWindow;
/**
* A map of TA Id to the statistic model of smooth exponential.
*/
private final ConcurrentMap<TaskAttemptId, private final ConcurrentMap<TaskAttemptId,
AtomicReference<SimpleExponentialSmoothing>> AtomicReference<SimpleExponentialSmoothing>>
estimates = new ConcurrentHashMap<>(); estimates = new ConcurrentHashMap<>();
private SimpleExponentialSmoothing getForecastEntry(TaskAttemptId attemptID) { private SimpleExponentialSmoothing getForecastEntry(
final TaskAttemptId attemptID) {
AtomicReference<SimpleExponentialSmoothing> entryRef = estimates AtomicReference<SimpleExponentialSmoothing> entryRef = estimates
.get(attemptID); .get(attemptID);
if (entryRef == null) { if (entryRef == null) {
@ -66,13 +85,13 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
return entryRef.get(); return entryRef.get();
} }
private void incorporateReading(TaskAttemptId attemptID, private void incorporateReading(final TaskAttemptId attemptID,
float newRawData, long newTimeStamp) { final float newRawData, final long newTimeStamp) {
SimpleExponentialSmoothing foreCastEntry = getForecastEntry(attemptID); SimpleExponentialSmoothing foreCastEntry = getForecastEntry(attemptID);
if (foreCastEntry == null) { if (foreCastEntry == null) {
Long tStartTime = startTimes.get(attemptID); Long tStartTime = startTimes.get(attemptID);
// skip if the startTime is not set yet // skip if the startTime is not set yet
if(tStartTime == null) { if (tStartTime == null) {
return; return;
} }
estimates.putIfAbsent(attemptID, estimates.putIfAbsent(attemptID,
@ -86,7 +105,8 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
} }
@Override @Override
public void contextualize(Configuration conf, AppContext context) { public void contextualize(final Configuration conf,
final AppContext context) {
super.contextualize(conf, context); super.contextualize(conf, context);
constTime constTime
@ -103,18 +123,16 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
} }
@Override @Override
public long estimatedRuntime(TaskAttemptId id) { public long estimatedRuntime(final TaskAttemptId id) {
SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id); SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id);
if (foreCastEntry == null) { if (foreCastEntry == null) {
return DEFAULT_ESTIMATE_RUNTIME; return DEFAULT_ESTIMATE_RUNTIME;
} }
// TODO: What should we do when estimate is zero double remainingWork = Math
double remainingWork = Math.min(1.0, 1.0 - foreCastEntry.getRawData()); .max(0.0, Math.min(1.0, 1.0 - foreCastEntry.getRawData()));
double forecast = foreCastEntry.getForecast(); double forecast = Math
if (forecast <= 0.0) { .max(DEFAULT_PROGRESS_VALUE, foreCastEntry.getForecast());
return DEFAULT_ESTIMATE_RUNTIME; long remainingTime = (long) (remainingWork / forecast);
}
long remainingTime = (long)(remainingWork / forecast);
long estimatedRuntime = remainingTime long estimatedRuntime = remainingTime
+ foreCastEntry.getTimeStamp() + foreCastEntry.getTimeStamp()
- foreCastEntry.getStartTime(); - foreCastEntry.getStartTime();
@ -122,30 +140,32 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
} }
@Override @Override
public long estimatedNewAttemptRuntime(TaskId id) { public long estimatedNewAttemptRuntime(final TaskId id) {
DataStatistics statistics = dataStatisticsForTask(id); DataStatistics statistics = dataStatisticsForTask(id);
if (statistics == null) { if (statistics == null) {
return -1L; return DEFAULT_ESTIMATE_RUNTIME;
} }
double statsMeanCI = statistics.meanCI(); double statsMeanCI = statistics.meanCI();
double expectedVal = double expectedVal =
statsMeanCI + Math.min(statsMeanCI * 0.25, statistics.std() / 2); statsMeanCI + Math.min(statsMeanCI * CONFIDENCE_INTERVAL_FACTOR,
return (long)(expectedVal); statistics.std() / 2);
return (long) (expectedVal);
} }
@Override @Override
public boolean hasStagnatedProgress(TaskAttemptId id, long timeStamp) { public boolean hasStagnatedProgress(final TaskAttemptId id,
final long timeStamp) {
SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id); SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id);
if(foreCastEntry == null) { if (foreCastEntry == null) {
return false; return false;
} }
return foreCastEntry.isDataStagnated(timeStamp); return foreCastEntry.isDataStagnated(timeStamp);
} }
@Override @Override
public long runtimeEstimateVariance(TaskAttemptId id) { public long runtimeEstimateVariance(final TaskAttemptId id) {
SimpleExponentialSmoothing forecastEntry = getForecastEntry(id); SimpleExponentialSmoothing forecastEntry = getForecastEntry(id);
if (forecastEntry == null) { if (forecastEntry == null) {
return DEFAULT_ESTIMATE_RUNTIME; return DEFAULT_ESTIMATE_RUNTIME;
@ -154,12 +174,13 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
if (forecastEntry.isDefaultForecast(forecast)) { if (forecastEntry.isDefaultForecast(forecast)) {
return DEFAULT_ESTIMATE_RUNTIME; return DEFAULT_ESTIMATE_RUNTIME;
} }
//TODO: What is the best way to measure variance in runtime //TODO What is the best way to measure variance in runtime
return 0L; return 0L;
} }
@Override @Override
public void updateAttempt(TaskAttemptStatus status, long timestamp) { public void updateAttempt(final TaskAttemptStatus status,
final long timestamp) {
super.updateAttempt(status, timestamp); super.updateAttempt(status, timestamp);
TaskAttemptId attemptID = status.id; TaskAttemptId attemptID = status.id;

View File

@ -24,108 +24,145 @@ import java.util.concurrent.atomic.AtomicReference;
* Implementation of the static model for Simple exponential smoothing. * Implementation of the static model for Simple exponential smoothing.
*/ */
public class SimpleExponentialSmoothing { public class SimpleExponentialSmoothing {
public final static double DEFAULT_FORECAST = -1.0; private static final double DEFAULT_FORECAST = -1.0;
private final int kMinimumReads; private final int kMinimumReads;
private final long kStagnatedWindow; private final long kStagnatedWindow;
private final long startTime; private final long startTime;
private long timeConstant; private long timeConstant;
/**
* Holds reference to the current forecast record.
*/
private AtomicReference<ForecastRecord> forecastRefEntry; private AtomicReference<ForecastRecord> forecastRefEntry;
public static SimpleExponentialSmoothing createForecast(long timeConstant, public static SimpleExponentialSmoothing createForecast(
int skipCnt, long stagnatedWindow, long timeStamp) { final long timeConstant,
final int skipCnt, final long stagnatedWindow, final long timeStamp) {
return new SimpleExponentialSmoothing(timeConstant, skipCnt, return new SimpleExponentialSmoothing(timeConstant, skipCnt,
stagnatedWindow, timeStamp); stagnatedWindow, timeStamp);
} }
SimpleExponentialSmoothing(long ktConstant, int skipCnt, SimpleExponentialSmoothing(final long ktConstant, final int skipCnt,
long stagnatedWindow, long timeStamp) { final long stagnatedWindow, final long timeStamp) {
kMinimumReads = skipCnt; this.kMinimumReads = skipCnt;
kStagnatedWindow = stagnatedWindow; this.kStagnatedWindow = stagnatedWindow;
this.timeConstant = ktConstant; this.timeConstant = ktConstant;
this.startTime = timeStamp; this.startTime = timeStamp;
this.forecastRefEntry = new AtomicReference<ForecastRecord>(null); this.forecastRefEntry = new AtomicReference<ForecastRecord>(null);
} }
private class ForecastRecord { private class ForecastRecord {
private double alpha; private final double alpha;
private long timeStamp; private final long timeStamp;
private double sample; private final double sample;
private double rawData; private final double rawData;
private double forecast; private double forecast;
private double sseError; private final double sseError;
private long myIndex; private final long myIndex;
private ForecastRecord prevRec;
ForecastRecord(double forecast, double rawData, long timeStamp) { ForecastRecord(final double currForecast, final double currRawData,
this(0.0, forecast, rawData, forecast, timeStamp, 0.0, 0); final long currTimeStamp) {
this(0.0, currForecast, currRawData, currForecast, currTimeStamp, 0.0, 0);
} }
ForecastRecord(double alpha, double sample, double rawData, ForecastRecord(final double alphaVal, final double currSample,
double forecast, long timeStamp, double accError, long index) { final double currRawData,
this.timeStamp = timeStamp; final double currForecast, final long currTimeStamp,
this.alpha = alpha; final double accError,
this.sseError = 0.0; final long index) {
this.sample = sample; this.timeStamp = currTimeStamp;
this.forecast = forecast; this.alpha = alphaVal;
this.rawData = rawData; this.sample = currSample;
this.forecast = currForecast;
this.rawData = currRawData;
this.sseError = accError; this.sseError = accError;
this.myIndex = index; this.myIndex = index;
} }
private double preProcessRawData(double rData, long newTime) { private ForecastRecord createForecastRecord(final double alphaVal,
final double currSample,
final double currRawData,
final double currForecast, final long currTimeStamp,
final double accError,
final long index,
final ForecastRecord prev) {
ForecastRecord forecastRec =
new ForecastRecord(alphaVal, currSample, currRawData, currForecast,
currTimeStamp, accError, index);
forecastRec.prevRec = prev;
return forecastRec;
}
private double preProcessRawData(final double rData, final long newTime) {
return processRawData(this.rawData, this.timeStamp, rData, newTime); return processRawData(this.rawData, this.timeStamp, rData, newTime);
} }
public ForecastRecord append(long newTimeStamp, double rData) { public ForecastRecord append(final long newTimeStamp, final double rData) {
if (this.timeStamp > newTimeStamp) { if (this.timeStamp >= newTimeStamp
&& Double.compare(this.rawData, rData) >= 0) {
// progress reported twice. Do nothing.
return this; return this;
} }
double newSample = preProcessRawData(rData, newTimeStamp); ForecastRecord refRecord = this;
if (newTimeStamp == this.timeStamp) {
// we need to restore old value if possible
if (this.prevRec != null) {
refRecord = this.prevRec;
}
}
double newSample = refRecord.preProcessRawData(rData, newTimeStamp);
long deltaTime = this.timeStamp - newTimeStamp; long deltaTime = this.timeStamp - newTimeStamp;
if (this.myIndex == kMinimumReads) { if (refRecord.myIndex == kMinimumReads) {
timeConstant = Math.max(timeConstant, newTimeStamp - startTime); timeConstant = Math.max(timeConstant, newTimeStamp - startTime);
} }
double smoothFactor = double smoothFactor =
1 - Math.exp(((double) deltaTime) / timeConstant); 1 - Math.exp(((double) deltaTime) / timeConstant);
double forecastVal = double forecastVal =
smoothFactor * newSample + (1.0 - smoothFactor) * this.forecast; smoothFactor * newSample + (1.0 - smoothFactor) * refRecord.forecast;
double newSSEError = double newSSEError =
this.sseError + Math.pow(newSample - this.forecast, 2); refRecord.sseError + Math.pow(newSample - refRecord.forecast, 2);
return new ForecastRecord(smoothFactor, newSample, rData, forecastVal, return refRecord
newTimeStamp, newSSEError, this.myIndex + 1); .createForecastRecord(smoothFactor, newSample, rData, forecastVal,
newTimeStamp, newSSEError, refRecord.myIndex + 1, refRecord);
} }
} }
public boolean isDataStagnated(long timeStamp) { /**
* checks if the task is hanging up.
* @param timeStamp current time of the scan.
* @return true if we have number of samples > kMinimumReads and the record
* timestamp has expired.
*/
public boolean isDataStagnated(final long timeStamp) {
ForecastRecord rec = forecastRefEntry.get(); ForecastRecord rec = forecastRefEntry.get();
if (rec != null && rec.myIndex <= kMinimumReads) { if (rec != null && rec.myIndex > kMinimumReads) {
return (rec.timeStamp + kStagnatedWindow) < timeStamp; return (rec.timeStamp + kStagnatedWindow) > timeStamp;
} }
return false; return false;
} }
static double processRawData(double oldRawData, long oldTime, static double processRawData(final double oldRawData, final long oldTime,
double newRawData, long newTime) { final double newRawData, final long newTime) {
double rate = (newRawData - oldRawData) / (newTime - oldTime); double rate = (newRawData - oldRawData) / (newTime - oldTime);
return rate; return rate;
} }
public void incorporateReading(long timeStamp, double rawData) { public void incorporateReading(final long timeStamp,
final double currRawData) {
ForecastRecord oldRec = forecastRefEntry.get(); ForecastRecord oldRec = forecastRefEntry.get();
if (oldRec == null) { if (oldRec == null) {
double oldForecast = double oldForecast =
processRawData(0, startTime, rawData, timeStamp); processRawData(0, startTime, currRawData, timeStamp);
forecastRefEntry.compareAndSet(null, forecastRefEntry.compareAndSet(null,
new ForecastRecord(oldForecast, 0.0, startTime)); new ForecastRecord(oldForecast, 0.0, startTime));
incorporateReading(timeStamp, rawData); incorporateReading(timeStamp, currRawData);
return; return;
} }
while (!forecastRefEntry.compareAndSet(oldRec, oldRec.append(timeStamp, while (!forecastRefEntry.compareAndSet(oldRec, oldRec.append(timeStamp,
rawData))) { currRawData))) {
oldRec = forecastRefEntry.get(); oldRec = forecastRefEntry.get();
} }
} }
public double getForecast() { public double getForecast() {
@ -136,7 +173,7 @@ public class SimpleExponentialSmoothing {
return DEFAULT_FORECAST; return DEFAULT_FORECAST;
} }
public boolean isDefaultForecast(double value) { public boolean isDefaultForecast(final double value) {
return value == DEFAULT_FORECAST; return value == DEFAULT_FORECAST;
} }
@ -148,7 +185,7 @@ public class SimpleExponentialSmoothing {
return DEFAULT_FORECAST; return DEFAULT_FORECAST;
} }
public boolean isErrorWithinBound(double bound) { public boolean isErrorWithinBound(final double bound) {
double squaredErr = getSSE(); double squaredErr = getSSE();
if (squaredErr < 0) { if (squaredErr < 0) {
return false; return false;
@ -185,8 +222,8 @@ public class SimpleExponentialSmoothing {
String res = "NULL"; String res = "NULL";
ForecastRecord rec = forecastRefEntry.get(); ForecastRecord rec = forecastRefEntry.get();
if (rec != null) { if (rec != null) {
res = "rec.index = " + rec.myIndex + ", forecast t: " + rec.timeStamp + res = "rec.index = " + rec.myIndex + ", forecast t: " + rec.timeStamp
", forecast: " + rec.forecast + ", forecast: " + rec.forecast
+ ", sample: " + rec.sample + ", raw: " + rec.rawData + ", error: " + ", sample: " + rec.sample + ", raw: " + rec.rawData + ", error: "
+ rec.sseError + ", alpha: " + rec.alpha; + rec.sseError + ", alpha: " + rec.alpha;
} }

View File

@ -0,0 +1,20 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
@InterfaceAudience.Private
package org.apache.hadoop.mapreduce.v2.app.speculate.forecast;
import org.apache.hadoop.classification.InterfaceAudience;

View File

@ -22,8 +22,11 @@ import java.io.File;
import java.io.FileOutputStream; import java.io.FileOutputStream;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.EnumSet; import java.util.EnumSet;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileContext; import org.apache.hadoop.fs.FileContext;
import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.FileSystem;
@ -372,18 +375,45 @@ public class MRApp extends MRAppMaster {
TaskAttemptReport report = attempt.getReport(); TaskAttemptReport report = attempt.getReport();
while (!finalState.equals(report.getTaskAttemptState()) && while (!finalState.equals(report.getTaskAttemptState()) &&
timeoutSecs++ < 20) { timeoutSecs++ < 20) {
System.out.println("TaskAttempt State is : " + report.getTaskAttemptState() + System.out.println(
" Waiting for state : " + finalState + "TaskAttempt " + attempt.getID().toString() + " State is : "
" progress : " + report.getProgress()); + report.getTaskAttemptState()
+ " Waiting for state : " + finalState
+ " progress : " + report.getProgress());
report = attempt.getReport(); report = attempt.getReport();
Thread.sleep(500); Thread.sleep(500);
} }
System.out.println("TaskAttempt State is : " + report.getTaskAttemptState()); System.out.println("TaskAttempt State is : "
+ report.getTaskAttemptState());
Assert.assertEquals("TaskAttempt state is not correct (timedout)", Assert.assertEquals("TaskAttempt state is not correct (timedout)",
finalState, finalState,
report.getTaskAttemptState()); report.getTaskAttemptState());
} }
public void waitForState(TaskAttempt attempt,
TaskAttemptState...finalStates) throws Exception {
int timeoutSecs = 0;
TaskAttemptReport report = attempt.getReport();
List<TaskAttemptState> targetStates = Arrays.asList(finalStates);
String statesValues = targetStates.stream().map(Object::toString).collect(
Collectors.joining(","));
while (!targetStates.contains(report.getTaskAttemptState()) &&
timeoutSecs++ < 20) {
System.out.println(
"TaskAttempt " + attempt.getID().toString() + " State is : "
+ report.getTaskAttemptState()
+ " Waiting for states: " + statesValues
+ ". curent state is : " + report.getTaskAttemptState()
+ ". progress : " + report.getProgress());
report = attempt.getReport();
Thread.sleep(500);
}
System.out.println("TaskAttempt State is : "
+ report.getTaskAttemptState());
Assert.assertTrue("TaskAttempt state is not correct (timedout)",
targetStates.contains(report.getTaskAttemptState()));
}
public void waitForState(Task task, TaskState finalState) throws Exception { public void waitForState(Task task, TaskState finalState) throws Exception {
int timeoutSecs = 0; int timeoutSecs = 0;
TaskReport report = task.getReport(); TaskReport report = task.getReport();

View File

@ -18,11 +18,14 @@
package org.apache.hadoop.mapreduce.v2; package org.apache.hadoop.mapreduce.v2;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Iterator; import java.util.Iterator;
import java.util.Map; import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import org.apache.hadoop.mapreduce.MRJobConfig; import org.apache.hadoop.mapreduce.MRJobConfig;
@ -50,19 +53,94 @@ import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.hadoop.yarn.util.Clock; import org.apache.hadoop.yarn.util.Clock;
import org.apache.hadoop.yarn.util.ControlledClock; import org.apache.hadoop.yarn.util.ControlledClock;
import org.apache.hadoop.yarn.util.SystemClock; import org.apache.hadoop.yarn.util.SystemClock;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import com.google.common.base.Supplier; import com.google.common.base.Supplier;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.runners.Parameterized;
import org.junit.runners.model.Statement;
/**
* The type Test speculative execution with mr app.
* It test the speculation behavior given a list of estimator classes.
*/
@SuppressWarnings({ "unchecked", "rawtypes" }) @SuppressWarnings({ "unchecked", "rawtypes" })
@RunWith(Parameterized.class) @RunWith(Parameterized.class)
public class TestSpeculativeExecutionWithMRApp { public class TestSpeculativeExecutionWithMRApp {
/** Number of times to re-try the failing tests. */
private static final int ASSERT_SPECULATIONS_COUNT_RETRIES = 3;
private static final int NUM_MAPPERS = 5; private static final int NUM_MAPPERS = 5;
private static final int NUM_REDUCERS = 0; private static final int NUM_REDUCERS = 0;
/**
* Speculation has non-deterministic behavior due to racing and timing. Use
* retry to verify that junit tests can pass.
*/
@Retention(RetentionPolicy.RUNTIME)
public @interface Retry {}
/**
* The type Retry rule.
*/
class RetryRule implements TestRule {
private AtomicInteger retryCount;
/**
* Instantiates a new Retry rule.
*
* @param retries the retries
*/
RetryRule(int retries) {
super();
this.retryCount = new AtomicInteger(retries);
}
@Override
public Statement apply(final Statement base,
final Description description) {
return new Statement() {
@Override
public void evaluate() throws Throwable {
Throwable caughtThrowable = null;
while (retryCount.getAndDecrement() > 0) {
try {
base.evaluate();
return;
} catch (Throwable t) {
if (retryCount.get() > 0 &&
description.getAnnotation(Retry.class) != null) {
caughtThrowable = t;
System.out.println(
description.getDisplayName() +
": Failed, " +
retryCount.toString() +
" retries remain");
} else {
throw caughtThrowable;
}
}
}
}
};
}
}
/**
* The Rule.
*/
@Rule
public RetryRule rule = new RetryRule(ASSERT_SPECULATIONS_COUNT_RETRIES);
/**
* Gets test parameters.
*
* @return the test parameters
*/
@Parameterized.Parameters(name = "{index}: TaskEstimator(EstimatorClass {0})") @Parameterized.Parameters(name = "{index}: TaskEstimator(EstimatorClass {0})")
public static Collection<Object[]> getTestParameters() { public static Collection<Object[]> getTestParameters() {
return Arrays.asList(new Object[][] { return Arrays.asList(new Object[][] {
@ -73,12 +151,23 @@ public class TestSpeculativeExecutionWithMRApp {
private Class<? extends TaskRuntimeEstimator> estimatorClass; private Class<? extends TaskRuntimeEstimator> estimatorClass;
/**
* Instantiates a new Test speculative execution with mr app.
*
* @param estimatorKlass the estimator klass
*/
public TestSpeculativeExecutionWithMRApp( public TestSpeculativeExecutionWithMRApp(
Class<? extends TaskRuntimeEstimator> estimatorKlass) { Class<? extends TaskRuntimeEstimator> estimatorKlass) {
this.estimatorClass = estimatorKlass; this.estimatorClass = estimatorKlass;
} }
@Test /**
* Test speculate successful without update events.
*
* @throws Exception the exception
*/
@Retry
@Test (timeout = 360000)
public void testSpeculateSuccessfulWithoutUpdateEvents() throws Exception { public void testSpeculateSuccessfulWithoutUpdateEvents() throws Exception {
Clock actualClock = SystemClock.getInstance(); Clock actualClock = SystemClock.getInstance();
@ -128,7 +217,8 @@ public class TestSpeculativeExecutionWithMRApp {
TaskAttemptEventType.TA_DONE)); TaskAttemptEventType.TA_DONE));
appEventHandler.handle(new TaskAttemptEvent(taskAttempt.getKey(), appEventHandler.handle(new TaskAttemptEvent(taskAttempt.getKey(),
TaskAttemptEventType.TA_CONTAINER_COMPLETED)); TaskAttemptEventType.TA_CONTAINER_COMPLETED));
app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED); app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED,
TaskAttemptState.KILLED);
} }
} }
} }
@ -150,8 +240,14 @@ public class TestSpeculativeExecutionWithMRApp {
app.waitForState(Service.STATE.STOPPED); app.waitForState(Service.STATE.STOPPED);
} }
@Test /**
public void testSepculateSuccessfulWithUpdateEvents() throws Exception { * Test speculate successful with update events.
*
* @throws Exception the exception
*/
@Retry
@Test (timeout = 360000)
public void testSpeculateSuccessfulWithUpdateEvents() throws Exception {
Clock actualClock = SystemClock.getInstance(); Clock actualClock = SystemClock.getInstance();
final ControlledClock clock = new ControlledClock(actualClock); final ControlledClock clock = new ControlledClock(actualClock);
@ -198,7 +294,8 @@ public class TestSpeculativeExecutionWithMRApp {
appEventHandler.handle(new TaskAttemptEvent(taskAttempt.getKey(), appEventHandler.handle(new TaskAttemptEvent(taskAttempt.getKey(),
TaskAttemptEventType.TA_CONTAINER_COMPLETED)); TaskAttemptEventType.TA_CONTAINER_COMPLETED));
numTasksToFinish--; numTasksToFinish--;
app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED); app.waitForState(taskAttempt.getValue(), TaskAttemptState.KILLED,
TaskAttemptState.SUCCEEDED);
} else { } else {
// The last task is chosen for speculation // The last task is chosen for speculation
TaskAttemptStatus status = TaskAttemptStatus status =
@ -214,13 +311,12 @@ public class TestSpeculativeExecutionWithMRApp {
} }
clock.setTime(System.currentTimeMillis() + 15000); clock.setTime(System.currentTimeMillis() + 15000);
// give a chance to the speculator thread to run a scan before we proceed
// with updating events
Thread.yield();
for (Map.Entry<TaskId, Task> task : tasks.entrySet()) { for (Map.Entry<TaskId, Task> task : tasks.entrySet()) {
for (Map.Entry<TaskAttemptId, TaskAttempt> taskAttempt : task.getValue() for (Map.Entry<TaskAttemptId, TaskAttempt> taskAttempt : task.getValue()
.getAttempts().entrySet()) { .getAttempts().entrySet()) {
if (taskAttempt.getValue().getState() != TaskAttemptState.SUCCEEDED) { if (!(taskAttempt.getValue().getState() == TaskAttemptState.SUCCEEDED
|| taskAttempt.getValue().getState() == TaskAttemptState.KILLED)) {
TaskAttemptStatus status = TaskAttemptStatus status =
createTaskAttemptStatus(taskAttempt.getKey(), (float) 0.75, createTaskAttemptStatus(taskAttempt.getKey(), (float) 0.75,
TaskAttemptState.RUNNING); TaskAttemptState.RUNNING);