MAPREDUCE-7252. Handling 0 progress in SimpleExponential task runtime estimator
Signed-off-by: Jonathan Eagles <jeagles@gmail.com>
This commit is contained in:
parent
52cc20e9ea
commit
cdd6efd3ab
|
@ -18,6 +18,11 @@
|
|||
package org.apache.hadoop.mapreduce.v2.app.speculate;
|
||||
|
||||
public class DataStatistics {
|
||||
|
||||
/**
|
||||
* factor used to calculate confidence interval within 95%.
|
||||
*/
|
||||
private static final double DEFAULT_CI_FACTOR = 1.96;
|
||||
private int count = 0;
|
||||
private double sum = 0;
|
||||
private double sumSquares = 0;
|
||||
|
@ -25,25 +30,26 @@ public class DataStatistics {
|
|||
public DataStatistics() {
|
||||
}
|
||||
|
||||
public DataStatistics(double initNum) {
|
||||
public DataStatistics(final double initNum) {
|
||||
this.count = 1;
|
||||
this.sum = initNum;
|
||||
this.sumSquares = initNum * initNum;
|
||||
}
|
||||
|
||||
public synchronized void add(double newNum) {
|
||||
public synchronized void add(final double newNum) {
|
||||
this.count++;
|
||||
this.sum += newNum;
|
||||
this.sumSquares += newNum * newNum;
|
||||
}
|
||||
|
||||
public synchronized void updateStatistics(double old, double update) {
|
||||
public synchronized void updateStatistics(final double old,
|
||||
final double update) {
|
||||
this.sum += update - old;
|
||||
this.sumSquares += (update * update) - (old * old);
|
||||
}
|
||||
|
||||
public synchronized double mean() {
|
||||
return count == 0 ? 0.0 : sum/count;
|
||||
return count == 0 ? 0.0 : sum / count;
|
||||
}
|
||||
|
||||
public synchronized double var() {
|
||||
|
@ -52,14 +58,14 @@ public class DataStatistics {
|
|||
return 0.0;
|
||||
}
|
||||
double mean = mean();
|
||||
return Math.max((sumSquares/count) - mean * mean, 0.0d);
|
||||
return Math.max((sumSquares / count) - mean * mean, 0.0d);
|
||||
}
|
||||
|
||||
public synchronized double std() {
|
||||
return Math.sqrt(this.var());
|
||||
}
|
||||
|
||||
public synchronized double outlier(float sigma) {
|
||||
public synchronized double outlier(final float sigma) {
|
||||
if (count != 0.0) {
|
||||
return mean() + std() * sigma;
|
||||
}
|
||||
|
@ -78,10 +84,12 @@ public class DataStatistics {
|
|||
* @return the mean value adding 95% confidence interval
|
||||
*/
|
||||
public synchronized double meanCI() {
|
||||
if (count <= 1) return 0.0;
|
||||
if (count <= 1) {
|
||||
return 0.0;
|
||||
}
|
||||
double currMean = mean();
|
||||
double currStd = std();
|
||||
return currMean + (1.96 * currStd / Math.sqrt(count));
|
||||
return currMean + (DEFAULT_CI_FACTOR * currStd / Math.sqrt(count));
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
|
|
|
@ -33,7 +33,22 @@ import org.apache.hadoop.mapreduce.v2.app.speculate.forecast.SimpleExponentialSm
|
|||
* A task Runtime Estimator based on exponential smoothing.
|
||||
*/
|
||||
public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
|
||||
private final static long DEFAULT_ESTIMATE_RUNTIME = -1L;
|
||||
|
||||
/**
|
||||
* The default value returned by the estimator when no records exist.
|
||||
*/
|
||||
private static final long DEFAULT_ESTIMATE_RUNTIME = -1L;
|
||||
|
||||
/**
|
||||
* Given a forecast of value 0.0, it is getting replaced by the default value
|
||||
* to avoid division by 0.
|
||||
*/
|
||||
private static final double DEFAULT_PROGRESS_VALUE = 1E-10;
|
||||
|
||||
/**
|
||||
* Factor used to calculate the confidence interval.
|
||||
*/
|
||||
private static final double CONFIDENCE_INTERVAL_FACTOR = 0.25;
|
||||
|
||||
/**
|
||||
* Constant time used to calculate the smoothing exponential factor.
|
||||
|
@ -53,11 +68,15 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
|
|||
*/
|
||||
private long stagnatedWindow;
|
||||
|
||||
/**
|
||||
* A map of TA Id to the statistic model of smooth exponential.
|
||||
*/
|
||||
private final ConcurrentMap<TaskAttemptId,
|
||||
AtomicReference<SimpleExponentialSmoothing>>
|
||||
estimates = new ConcurrentHashMap<>();
|
||||
|
||||
private SimpleExponentialSmoothing getForecastEntry(TaskAttemptId attemptID) {
|
||||
private SimpleExponentialSmoothing getForecastEntry(
|
||||
final TaskAttemptId attemptID) {
|
||||
AtomicReference<SimpleExponentialSmoothing> entryRef = estimates
|
||||
.get(attemptID);
|
||||
if (entryRef == null) {
|
||||
|
@ -66,13 +85,13 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
|
|||
return entryRef.get();
|
||||
}
|
||||
|
||||
private void incorporateReading(TaskAttemptId attemptID,
|
||||
float newRawData, long newTimeStamp) {
|
||||
private void incorporateReading(final TaskAttemptId attemptID,
|
||||
final float newRawData, final long newTimeStamp) {
|
||||
SimpleExponentialSmoothing foreCastEntry = getForecastEntry(attemptID);
|
||||
if (foreCastEntry == null) {
|
||||
Long tStartTime = startTimes.get(attemptID);
|
||||
// skip if the startTime is not set yet
|
||||
if(tStartTime == null) {
|
||||
if (tStartTime == null) {
|
||||
return;
|
||||
}
|
||||
estimates.putIfAbsent(attemptID,
|
||||
|
@ -86,7 +105,8 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void contextualize(Configuration conf, AppContext context) {
|
||||
public void contextualize(final Configuration conf,
|
||||
final AppContext context) {
|
||||
super.contextualize(conf, context);
|
||||
|
||||
constTime
|
||||
|
@ -103,18 +123,16 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public long estimatedRuntime(TaskAttemptId id) {
|
||||
public long estimatedRuntime(final TaskAttemptId id) {
|
||||
SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id);
|
||||
if (foreCastEntry == null) {
|
||||
return DEFAULT_ESTIMATE_RUNTIME;
|
||||
}
|
||||
// TODO: What should we do when estimate is zero
|
||||
double remainingWork = Math.min(1.0, 1.0 - foreCastEntry.getRawData());
|
||||
double forecast = foreCastEntry.getForecast();
|
||||
if (forecast <= 0.0) {
|
||||
return DEFAULT_ESTIMATE_RUNTIME;
|
||||
}
|
||||
long remainingTime = (long)(remainingWork / forecast);
|
||||
double remainingWork = Math
|
||||
.max(0.0, Math.min(1.0, 1.0 - foreCastEntry.getRawData()));
|
||||
double forecast = Math
|
||||
.max(DEFAULT_PROGRESS_VALUE, foreCastEntry.getForecast());
|
||||
long remainingTime = (long) (remainingWork / forecast);
|
||||
long estimatedRuntime = remainingTime
|
||||
+ foreCastEntry.getTimeStamp()
|
||||
- foreCastEntry.getStartTime();
|
||||
|
@ -122,30 +140,32 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public long estimatedNewAttemptRuntime(TaskId id) {
|
||||
public long estimatedNewAttemptRuntime(final TaskId id) {
|
||||
DataStatistics statistics = dataStatisticsForTask(id);
|
||||
|
||||
if (statistics == null) {
|
||||
return -1L;
|
||||
return DEFAULT_ESTIMATE_RUNTIME;
|
||||
}
|
||||
|
||||
double statsMeanCI = statistics.meanCI();
|
||||
double expectedVal =
|
||||
statsMeanCI + Math.min(statsMeanCI * 0.25, statistics.std() / 2);
|
||||
return (long)(expectedVal);
|
||||
statsMeanCI + Math.min(statsMeanCI * CONFIDENCE_INTERVAL_FACTOR,
|
||||
statistics.std() / 2);
|
||||
return (long) (expectedVal);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasStagnatedProgress(TaskAttemptId id, long timeStamp) {
|
||||
public boolean hasStagnatedProgress(final TaskAttemptId id,
|
||||
final long timeStamp) {
|
||||
SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id);
|
||||
if(foreCastEntry == null) {
|
||||
if (foreCastEntry == null) {
|
||||
return false;
|
||||
}
|
||||
return foreCastEntry.isDataStagnated(timeStamp);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long runtimeEstimateVariance(TaskAttemptId id) {
|
||||
public long runtimeEstimateVariance(final TaskAttemptId id) {
|
||||
SimpleExponentialSmoothing forecastEntry = getForecastEntry(id);
|
||||
if (forecastEntry == null) {
|
||||
return DEFAULT_ESTIMATE_RUNTIME;
|
||||
|
@ -154,12 +174,13 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
|
|||
if (forecastEntry.isDefaultForecast(forecast)) {
|
||||
return DEFAULT_ESTIMATE_RUNTIME;
|
||||
}
|
||||
//TODO: What is the best way to measure variance in runtime
|
||||
//TODO What is the best way to measure variance in runtime
|
||||
return 0L;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateAttempt(TaskAttemptStatus status, long timestamp) {
|
||||
public void updateAttempt(final TaskAttemptStatus status,
|
||||
final long timestamp) {
|
||||
super.updateAttempt(status, timestamp);
|
||||
TaskAttemptId attemptID = status.id;
|
||||
|
||||
|
|
|
@ -24,108 +24,145 @@ import java.util.concurrent.atomic.AtomicReference;
|
|||
* Implementation of the static model for Simple exponential smoothing.
|
||||
*/
|
||||
public class SimpleExponentialSmoothing {
|
||||
public final static double DEFAULT_FORECAST = -1.0;
|
||||
private static final double DEFAULT_FORECAST = -1.0;
|
||||
private final int kMinimumReads;
|
||||
private final long kStagnatedWindow;
|
||||
private final long startTime;
|
||||
private long timeConstant;
|
||||
|
||||
/**
|
||||
* Holds reference to the current forecast record.
|
||||
*/
|
||||
private AtomicReference<ForecastRecord> forecastRefEntry;
|
||||
|
||||
public static SimpleExponentialSmoothing createForecast(long timeConstant,
|
||||
int skipCnt, long stagnatedWindow, long timeStamp) {
|
||||
public static SimpleExponentialSmoothing createForecast(
|
||||
final long timeConstant,
|
||||
final int skipCnt, final long stagnatedWindow, final long timeStamp) {
|
||||
return new SimpleExponentialSmoothing(timeConstant, skipCnt,
|
||||
stagnatedWindow, timeStamp);
|
||||
}
|
||||
|
||||
SimpleExponentialSmoothing(long ktConstant, int skipCnt,
|
||||
long stagnatedWindow, long timeStamp) {
|
||||
kMinimumReads = skipCnt;
|
||||
kStagnatedWindow = stagnatedWindow;
|
||||
SimpleExponentialSmoothing(final long ktConstant, final int skipCnt,
|
||||
final long stagnatedWindow, final long timeStamp) {
|
||||
this.kMinimumReads = skipCnt;
|
||||
this.kStagnatedWindow = stagnatedWindow;
|
||||
this.timeConstant = ktConstant;
|
||||
this.startTime = timeStamp;
|
||||
this.forecastRefEntry = new AtomicReference<ForecastRecord>(null);
|
||||
}
|
||||
|
||||
private class ForecastRecord {
|
||||
private double alpha;
|
||||
private long timeStamp;
|
||||
private double sample;
|
||||
private double rawData;
|
||||
private final double alpha;
|
||||
private final long timeStamp;
|
||||
private final double sample;
|
||||
private final double rawData;
|
||||
private double forecast;
|
||||
private double sseError;
|
||||
private long myIndex;
|
||||
private final double sseError;
|
||||
private final long myIndex;
|
||||
private ForecastRecord prevRec;
|
||||
|
||||
ForecastRecord(double forecast, double rawData, long timeStamp) {
|
||||
this(0.0, forecast, rawData, forecast, timeStamp, 0.0, 0);
|
||||
ForecastRecord(final double currForecast, final double currRawData,
|
||||
final long currTimeStamp) {
|
||||
this(0.0, currForecast, currRawData, currForecast, currTimeStamp, 0.0, 0);
|
||||
}
|
||||
|
||||
ForecastRecord(double alpha, double sample, double rawData,
|
||||
double forecast, long timeStamp, double accError, long index) {
|
||||
this.timeStamp = timeStamp;
|
||||
this.alpha = alpha;
|
||||
this.sseError = 0.0;
|
||||
this.sample = sample;
|
||||
this.forecast = forecast;
|
||||
this.rawData = rawData;
|
||||
ForecastRecord(final double alphaVal, final double currSample,
|
||||
final double currRawData,
|
||||
final double currForecast, final long currTimeStamp,
|
||||
final double accError,
|
||||
final long index) {
|
||||
this.timeStamp = currTimeStamp;
|
||||
this.alpha = alphaVal;
|
||||
this.sample = currSample;
|
||||
this.forecast = currForecast;
|
||||
this.rawData = currRawData;
|
||||
this.sseError = accError;
|
||||
this.myIndex = index;
|
||||
}
|
||||
|
||||
private double preProcessRawData(double rData, long newTime) {
|
||||
private ForecastRecord createForecastRecord(final double alphaVal,
|
||||
final double currSample,
|
||||
final double currRawData,
|
||||
final double currForecast, final long currTimeStamp,
|
||||
final double accError,
|
||||
final long index,
|
||||
final ForecastRecord prev) {
|
||||
ForecastRecord forecastRec =
|
||||
new ForecastRecord(alphaVal, currSample, currRawData, currForecast,
|
||||
currTimeStamp, accError, index);
|
||||
forecastRec.prevRec = prev;
|
||||
return forecastRec;
|
||||
}
|
||||
|
||||
private double preProcessRawData(final double rData, final long newTime) {
|
||||
return processRawData(this.rawData, this.timeStamp, rData, newTime);
|
||||
}
|
||||
|
||||
public ForecastRecord append(long newTimeStamp, double rData) {
|
||||
if (this.timeStamp > newTimeStamp) {
|
||||
public ForecastRecord append(final long newTimeStamp, final double rData) {
|
||||
if (this.timeStamp >= newTimeStamp
|
||||
&& Double.compare(this.rawData, rData) >= 0) {
|
||||
// progress reported twice. Do nothing.
|
||||
return this;
|
||||
}
|
||||
double newSample = preProcessRawData(rData, newTimeStamp);
|
||||
ForecastRecord refRecord = this;
|
||||
if (newTimeStamp == this.timeStamp) {
|
||||
// we need to restore old value if possible
|
||||
if (this.prevRec != null) {
|
||||
refRecord = this.prevRec;
|
||||
}
|
||||
}
|
||||
double newSample = refRecord.preProcessRawData(rData, newTimeStamp);
|
||||
long deltaTime = this.timeStamp - newTimeStamp;
|
||||
if (this.myIndex == kMinimumReads) {
|
||||
if (refRecord.myIndex == kMinimumReads) {
|
||||
timeConstant = Math.max(timeConstant, newTimeStamp - startTime);
|
||||
}
|
||||
double smoothFactor =
|
||||
1 - Math.exp(((double) deltaTime) / timeConstant);
|
||||
double forecastVal =
|
||||
smoothFactor * newSample + (1.0 - smoothFactor) * this.forecast;
|
||||
smoothFactor * newSample + (1.0 - smoothFactor) * refRecord.forecast;
|
||||
double newSSEError =
|
||||
this.sseError + Math.pow(newSample - this.forecast, 2);
|
||||
return new ForecastRecord(smoothFactor, newSample, rData, forecastVal,
|
||||
newTimeStamp, newSSEError, this.myIndex + 1);
|
||||
refRecord.sseError + Math.pow(newSample - refRecord.forecast, 2);
|
||||
return refRecord
|
||||
.createForecastRecord(smoothFactor, newSample, rData, forecastVal,
|
||||
newTimeStamp, newSSEError, refRecord.myIndex + 1, refRecord);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public boolean isDataStagnated(long timeStamp) {
|
||||
/**
|
||||
* checks if the task is hanging up.
|
||||
* @param timeStamp current time of the scan.
|
||||
* @return true if we have number of samples > kMinimumReads and the record
|
||||
* timestamp has expired.
|
||||
*/
|
||||
public boolean isDataStagnated(final long timeStamp) {
|
||||
ForecastRecord rec = forecastRefEntry.get();
|
||||
if (rec != null && rec.myIndex <= kMinimumReads) {
|
||||
return (rec.timeStamp + kStagnatedWindow) < timeStamp;
|
||||
if (rec != null && rec.myIndex > kMinimumReads) {
|
||||
return (rec.timeStamp + kStagnatedWindow) > timeStamp;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static double processRawData(double oldRawData, long oldTime,
|
||||
double newRawData, long newTime) {
|
||||
static double processRawData(final double oldRawData, final long oldTime,
|
||||
final double newRawData, final long newTime) {
|
||||
double rate = (newRawData - oldRawData) / (newTime - oldTime);
|
||||
return rate;
|
||||
}
|
||||
|
||||
public void incorporateReading(long timeStamp, double rawData) {
|
||||
public void incorporateReading(final long timeStamp,
|
||||
final double currRawData) {
|
||||
ForecastRecord oldRec = forecastRefEntry.get();
|
||||
if (oldRec == null) {
|
||||
double oldForecast =
|
||||
processRawData(0, startTime, rawData, timeStamp);
|
||||
processRawData(0, startTime, currRawData, timeStamp);
|
||||
forecastRefEntry.compareAndSet(null,
|
||||
new ForecastRecord(oldForecast, 0.0, startTime));
|
||||
incorporateReading(timeStamp, rawData);
|
||||
incorporateReading(timeStamp, currRawData);
|
||||
return;
|
||||
}
|
||||
while (!forecastRefEntry.compareAndSet(oldRec, oldRec.append(timeStamp,
|
||||
rawData))) {
|
||||
currRawData))) {
|
||||
oldRec = forecastRefEntry.get();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public double getForecast() {
|
||||
|
@ -136,7 +173,7 @@ public class SimpleExponentialSmoothing {
|
|||
return DEFAULT_FORECAST;
|
||||
}
|
||||
|
||||
public boolean isDefaultForecast(double value) {
|
||||
public boolean isDefaultForecast(final double value) {
|
||||
return value == DEFAULT_FORECAST;
|
||||
}
|
||||
|
||||
|
@ -148,7 +185,7 @@ public class SimpleExponentialSmoothing {
|
|||
return DEFAULT_FORECAST;
|
||||
}
|
||||
|
||||
public boolean isErrorWithinBound(double bound) {
|
||||
public boolean isErrorWithinBound(final double bound) {
|
||||
double squaredErr = getSSE();
|
||||
if (squaredErr < 0) {
|
||||
return false;
|
||||
|
@ -185,8 +222,8 @@ public class SimpleExponentialSmoothing {
|
|||
String res = "NULL";
|
||||
ForecastRecord rec = forecastRefEntry.get();
|
||||
if (rec != null) {
|
||||
res = "rec.index = " + rec.myIndex + ", forecast t: " + rec.timeStamp +
|
||||
", forecast: " + rec.forecast
|
||||
res = "rec.index = " + rec.myIndex + ", forecast t: " + rec.timeStamp
|
||||
+ ", forecast: " + rec.forecast
|
||||
+ ", sample: " + rec.sample + ", raw: " + rec.rawData + ", error: "
|
||||
+ rec.sseError + ", alpha: " + rec.alpha;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one
|
||||
* or more contributor license agreements. See the NOTICE file
|
||||
* distributed with this work for additional information
|
||||
* regarding copyright ownership. The ASF licenses this file
|
||||
* to you under the Apache License, Version 2.0 (the
|
||||
* "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
@InterfaceAudience.Private
|
||||
package org.apache.hadoop.mapreduce.v2.app.speculate.forecast;
|
||||
import org.apache.hadoop.classification.InterfaceAudience;
|
|
@ -22,8 +22,11 @@ import java.io.File;
|
|||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.util.Arrays;
|
||||
import java.util.EnumSet;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.FileContext;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
|
@ -372,18 +375,45 @@ public class MRApp extends MRAppMaster {
|
|||
TaskAttemptReport report = attempt.getReport();
|
||||
while (!finalState.equals(report.getTaskAttemptState()) &&
|
||||
timeoutSecs++ < 20) {
|
||||
System.out.println("TaskAttempt State is : " + report.getTaskAttemptState() +
|
||||
" Waiting for state : " + finalState +
|
||||
" progress : " + report.getProgress());
|
||||
System.out.println(
|
||||
"TaskAttempt " + attempt.getID().toString() + " State is : "
|
||||
+ report.getTaskAttemptState()
|
||||
+ " Waiting for state : " + finalState
|
||||
+ " progress : " + report.getProgress());
|
||||
report = attempt.getReport();
|
||||
Thread.sleep(500);
|
||||
}
|
||||
System.out.println("TaskAttempt State is : " + report.getTaskAttemptState());
|
||||
System.out.println("TaskAttempt State is : "
|
||||
+ report.getTaskAttemptState());
|
||||
Assert.assertEquals("TaskAttempt state is not correct (timedout)",
|
||||
finalState,
|
||||
report.getTaskAttemptState());
|
||||
}
|
||||
|
||||
public void waitForState(TaskAttempt attempt,
|
||||
TaskAttemptState...finalStates) throws Exception {
|
||||
int timeoutSecs = 0;
|
||||
TaskAttemptReport report = attempt.getReport();
|
||||
List<TaskAttemptState> targetStates = Arrays.asList(finalStates);
|
||||
String statesValues = targetStates.stream().map(Object::toString).collect(
|
||||
Collectors.joining(","));
|
||||
while (!targetStates.contains(report.getTaskAttemptState()) &&
|
||||
timeoutSecs++ < 20) {
|
||||
System.out.println(
|
||||
"TaskAttempt " + attempt.getID().toString() + " State is : "
|
||||
+ report.getTaskAttemptState()
|
||||
+ " Waiting for states: " + statesValues
|
||||
+ ". curent state is : " + report.getTaskAttemptState()
|
||||
+ ". progress : " + report.getProgress());
|
||||
report = attempt.getReport();
|
||||
Thread.sleep(500);
|
||||
}
|
||||
System.out.println("TaskAttempt State is : "
|
||||
+ report.getTaskAttemptState());
|
||||
Assert.assertTrue("TaskAttempt state is not correct (timedout)",
|
||||
targetStates.contains(report.getTaskAttemptState()));
|
||||
}
|
||||
|
||||
public void waitForState(Task task, TaskState finalState) throws Exception {
|
||||
int timeoutSecs = 0;
|
||||
TaskReport report = task.getReport();
|
||||
|
|
|
@ -18,11 +18,14 @@
|
|||
|
||||
package org.apache.hadoop.mapreduce.v2;
|
||||
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Iterator;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import org.apache.hadoop.mapreduce.MRJobConfig;
|
||||
|
@ -50,19 +53,94 @@ import org.apache.hadoop.yarn.event.EventHandler;
|
|||
import org.apache.hadoop.yarn.util.Clock;
|
||||
import org.apache.hadoop.yarn.util.ControlledClock;
|
||||
import org.apache.hadoop.yarn.util.SystemClock;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
|
||||
import com.google.common.base.Supplier;
|
||||
import org.junit.rules.TestRule;
|
||||
import org.junit.runner.Description;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.junit.runners.model.Statement;
|
||||
|
||||
/**
|
||||
* The type Test speculative execution with mr app.
|
||||
* It test the speculation behavior given a list of estimator classes.
|
||||
*/
|
||||
@SuppressWarnings({ "unchecked", "rawtypes" })
|
||||
@RunWith(Parameterized.class)
|
||||
public class TestSpeculativeExecutionWithMRApp {
|
||||
|
||||
/** Number of times to re-try the failing tests. */
|
||||
private static final int ASSERT_SPECULATIONS_COUNT_RETRIES = 3;
|
||||
private static final int NUM_MAPPERS = 5;
|
||||
private static final int NUM_REDUCERS = 0;
|
||||
|
||||
/**
|
||||
* Speculation has non-deterministic behavior due to racing and timing. Use
|
||||
* retry to verify that junit tests can pass.
|
||||
*/
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
public @interface Retry {}
|
||||
|
||||
/**
|
||||
* The type Retry rule.
|
||||
*/
|
||||
class RetryRule implements TestRule {
|
||||
|
||||
private AtomicInteger retryCount;
|
||||
|
||||
/**
|
||||
* Instantiates a new Retry rule.
|
||||
*
|
||||
* @param retries the retries
|
||||
*/
|
||||
RetryRule(int retries) {
|
||||
super();
|
||||
this.retryCount = new AtomicInteger(retries);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Statement apply(final Statement base,
|
||||
final Description description) {
|
||||
return new Statement() {
|
||||
@Override
|
||||
public void evaluate() throws Throwable {
|
||||
Throwable caughtThrowable = null;
|
||||
|
||||
while (retryCount.getAndDecrement() > 0) {
|
||||
try {
|
||||
base.evaluate();
|
||||
return;
|
||||
} catch (Throwable t) {
|
||||
if (retryCount.get() > 0 &&
|
||||
description.getAnnotation(Retry.class) != null) {
|
||||
caughtThrowable = t;
|
||||
System.out.println(
|
||||
description.getDisplayName() +
|
||||
": Failed, " +
|
||||
retryCount.toString() +
|
||||
" retries remain");
|
||||
} else {
|
||||
throw caughtThrowable;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The Rule.
|
||||
*/
|
||||
@Rule
|
||||
public RetryRule rule = new RetryRule(ASSERT_SPECULATIONS_COUNT_RETRIES);
|
||||
|
||||
/**
|
||||
* Gets test parameters.
|
||||
*
|
||||
* @return the test parameters
|
||||
*/
|
||||
@Parameterized.Parameters(name = "{index}: TaskEstimator(EstimatorClass {0})")
|
||||
public static Collection<Object[]> getTestParameters() {
|
||||
return Arrays.asList(new Object[][] {
|
||||
|
@ -73,12 +151,23 @@ public class TestSpeculativeExecutionWithMRApp {
|
|||
|
||||
private Class<? extends TaskRuntimeEstimator> estimatorClass;
|
||||
|
||||
/**
|
||||
* Instantiates a new Test speculative execution with mr app.
|
||||
*
|
||||
* @param estimatorKlass the estimator klass
|
||||
*/
|
||||
public TestSpeculativeExecutionWithMRApp(
|
||||
Class<? extends TaskRuntimeEstimator> estimatorKlass) {
|
||||
this.estimatorClass = estimatorKlass;
|
||||
}
|
||||
|
||||
@Test
|
||||
/**
|
||||
* Test speculate successful without update events.
|
||||
*
|
||||
* @throws Exception the exception
|
||||
*/
|
||||
@Retry
|
||||
@Test (timeout = 360000)
|
||||
public void testSpeculateSuccessfulWithoutUpdateEvents() throws Exception {
|
||||
|
||||
Clock actualClock = SystemClock.getInstance();
|
||||
|
@ -128,7 +217,8 @@ public class TestSpeculativeExecutionWithMRApp {
|
|||
TaskAttemptEventType.TA_DONE));
|
||||
appEventHandler.handle(new TaskAttemptEvent(taskAttempt.getKey(),
|
||||
TaskAttemptEventType.TA_CONTAINER_COMPLETED));
|
||||
app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED);
|
||||
app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED,
|
||||
TaskAttemptState.KILLED);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -150,8 +240,14 @@ public class TestSpeculativeExecutionWithMRApp {
|
|||
app.waitForState(Service.STATE.STOPPED);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSepculateSuccessfulWithUpdateEvents() throws Exception {
|
||||
/**
|
||||
* Test speculate successful with update events.
|
||||
*
|
||||
* @throws Exception the exception
|
||||
*/
|
||||
@Retry
|
||||
@Test (timeout = 360000)
|
||||
public void testSpeculateSuccessfulWithUpdateEvents() throws Exception {
|
||||
|
||||
Clock actualClock = SystemClock.getInstance();
|
||||
final ControlledClock clock = new ControlledClock(actualClock);
|
||||
|
@ -198,7 +294,8 @@ public class TestSpeculativeExecutionWithMRApp {
|
|||
appEventHandler.handle(new TaskAttemptEvent(taskAttempt.getKey(),
|
||||
TaskAttemptEventType.TA_CONTAINER_COMPLETED));
|
||||
numTasksToFinish--;
|
||||
app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED);
|
||||
app.waitForState(taskAttempt.getValue(), TaskAttemptState.KILLED,
|
||||
TaskAttemptState.SUCCEEDED);
|
||||
} else {
|
||||
// The last task is chosen for speculation
|
||||
TaskAttemptStatus status =
|
||||
|
@ -214,13 +311,12 @@ public class TestSpeculativeExecutionWithMRApp {
|
|||
}
|
||||
|
||||
clock.setTime(System.currentTimeMillis() + 15000);
|
||||
// give a chance to the speculator thread to run a scan before we proceed
|
||||
// with updating events
|
||||
Thread.yield();
|
||||
|
||||
for (Map.Entry<TaskId, Task> task : tasks.entrySet()) {
|
||||
for (Map.Entry<TaskAttemptId, TaskAttempt> taskAttempt : task.getValue()
|
||||
.getAttempts().entrySet()) {
|
||||
if (taskAttempt.getValue().getState() != TaskAttemptState.SUCCEEDED) {
|
||||
if (!(taskAttempt.getValue().getState() == TaskAttemptState.SUCCEEDED
|
||||
|| taskAttempt.getValue().getState() == TaskAttemptState.KILLED)) {
|
||||
TaskAttemptStatus status =
|
||||
createTaskAttemptStatus(taskAttempt.getKey(), (float) 0.75,
|
||||
TaskAttemptState.RUNNING);
|
||||
|
|
Loading…
Reference in New Issue