MAPREDUCE-7252. Handling 0 progress in SimpleExponential task runtime estimator

Signed-off-by: Jonathan Eagles <jeagles@gmail.com>
This commit is contained in:
Ahmed Hussein 2020-01-08 11:08:13 -06:00 committed by Jonathan Eagles
parent 52cc20e9ea
commit cdd6efd3ab
6 changed files with 308 additions and 96 deletions

View File

@ -18,6 +18,11 @@
package org.apache.hadoop.mapreduce.v2.app.speculate;
public class DataStatistics {
/**
* factor used to calculate confidence interval within 95%.
*/
private static final double DEFAULT_CI_FACTOR = 1.96;
private int count = 0;
private double sum = 0;
private double sumSquares = 0;
@ -25,19 +30,20 @@ public class DataStatistics {
public DataStatistics() {
}
public DataStatistics(double initNum) {
public DataStatistics(final double initNum) {
this.count = 1;
this.sum = initNum;
this.sumSquares = initNum * initNum;
}
public synchronized void add(double newNum) {
public synchronized void add(final double newNum) {
this.count++;
this.sum += newNum;
this.sumSquares += newNum * newNum;
}
public synchronized void updateStatistics(double old, double update) {
public synchronized void updateStatistics(final double old,
final double update) {
this.sum += update - old;
this.sumSquares += (update * update) - (old * old);
}
@ -59,7 +65,7 @@ public synchronized double std() {
return Math.sqrt(this.var());
}
public synchronized double outlier(float sigma) {
public synchronized double outlier(final float sigma) {
if (count != 0.0) {
return mean() + std() * sigma;
}
@ -78,10 +84,12 @@ public synchronized double count() {
* @return the mean value adding 95% confidence interval
*/
public synchronized double meanCI() {
if (count <= 1) return 0.0;
if (count <= 1) {
return 0.0;
}
double currMean = mean();
double currStd = std();
return currMean + (1.96 * currStd / Math.sqrt(count));
return currMean + (DEFAULT_CI_FACTOR * currStd / Math.sqrt(count));
}
public String toString() {

View File

@ -33,7 +33,22 @@
* A task Runtime Estimator based on exponential smoothing.
*/
public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
private final static long DEFAULT_ESTIMATE_RUNTIME = -1L;
/**
* The default value returned by the estimator when no records exist.
*/
private static final long DEFAULT_ESTIMATE_RUNTIME = -1L;
/**
* Given a forecast of value 0.0, it is getting replaced by the default value
* to avoid division by 0.
*/
private static final double DEFAULT_PROGRESS_VALUE = 1E-10;
/**
* Factor used to calculate the confidence interval.
*/
private static final double CONFIDENCE_INTERVAL_FACTOR = 0.25;
/**
* Constant time used to calculate the smoothing exponential factor.
@ -53,11 +68,15 @@ public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
*/
private long stagnatedWindow;
/**
* A map of TA Id to the statistic model of smooth exponential.
*/
private final ConcurrentMap<TaskAttemptId,
AtomicReference<SimpleExponentialSmoothing>>
estimates = new ConcurrentHashMap<>();
private SimpleExponentialSmoothing getForecastEntry(TaskAttemptId attemptID) {
private SimpleExponentialSmoothing getForecastEntry(
final TaskAttemptId attemptID) {
AtomicReference<SimpleExponentialSmoothing> entryRef = estimates
.get(attemptID);
if (entryRef == null) {
@ -66,8 +85,8 @@ private SimpleExponentialSmoothing getForecastEntry(TaskAttemptId attemptID) {
return entryRef.get();
}
private void incorporateReading(TaskAttemptId attemptID,
float newRawData, long newTimeStamp) {
private void incorporateReading(final TaskAttemptId attemptID,
final float newRawData, final long newTimeStamp) {
SimpleExponentialSmoothing foreCastEntry = getForecastEntry(attemptID);
if (foreCastEntry == null) {
Long tStartTime = startTimes.get(attemptID);
@ -86,7 +105,8 @@ private void incorporateReading(TaskAttemptId attemptID,
}
@Override
public void contextualize(Configuration conf, AppContext context) {
public void contextualize(final Configuration conf,
final AppContext context) {
super.contextualize(conf, context);
constTime
@ -103,17 +123,15 @@ public void contextualize(Configuration conf, AppContext context) {
}
@Override
public long estimatedRuntime(TaskAttemptId id) {
public long estimatedRuntime(final TaskAttemptId id) {
SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id);
if (foreCastEntry == null) {
return DEFAULT_ESTIMATE_RUNTIME;
}
// TODO: What should we do when estimate is zero
double remainingWork = Math.min(1.0, 1.0 - foreCastEntry.getRawData());
double forecast = foreCastEntry.getForecast();
if (forecast <= 0.0) {
return DEFAULT_ESTIMATE_RUNTIME;
}
double remainingWork = Math
.max(0.0, Math.min(1.0, 1.0 - foreCastEntry.getRawData()));
double forecast = Math
.max(DEFAULT_PROGRESS_VALUE, foreCastEntry.getForecast());
long remainingTime = (long) (remainingWork / forecast);
long estimatedRuntime = remainingTime
+ foreCastEntry.getTimeStamp()
@ -122,21 +140,23 @@ public long estimatedRuntime(TaskAttemptId id) {
}
@Override
public long estimatedNewAttemptRuntime(TaskId id) {
public long estimatedNewAttemptRuntime(final TaskId id) {
DataStatistics statistics = dataStatisticsForTask(id);
if (statistics == null) {
return -1L;
return DEFAULT_ESTIMATE_RUNTIME;
}
double statsMeanCI = statistics.meanCI();
double expectedVal =
statsMeanCI + Math.min(statsMeanCI * 0.25, statistics.std() / 2);
statsMeanCI + Math.min(statsMeanCI * CONFIDENCE_INTERVAL_FACTOR,
statistics.std() / 2);
return (long) (expectedVal);
}
@Override
public boolean hasStagnatedProgress(TaskAttemptId id, long timeStamp) {
public boolean hasStagnatedProgress(final TaskAttemptId id,
final long timeStamp) {
SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id);
if (foreCastEntry == null) {
return false;
@ -145,7 +165,7 @@ public boolean hasStagnatedProgress(TaskAttemptId id, long timeStamp) {
}
@Override
public long runtimeEstimateVariance(TaskAttemptId id) {
public long runtimeEstimateVariance(final TaskAttemptId id) {
SimpleExponentialSmoothing forecastEntry = getForecastEntry(id);
if (forecastEntry == null) {
return DEFAULT_ESTIMATE_RUNTIME;
@ -154,12 +174,13 @@ public long runtimeEstimateVariance(TaskAttemptId id) {
if (forecastEntry.isDefaultForecast(forecast)) {
return DEFAULT_ESTIMATE_RUNTIME;
}
//TODO: What is the best way to measure variance in runtime
//TODO What is the best way to measure variance in runtime
return 0L;
}
@Override
public void updateAttempt(TaskAttemptStatus status, long timestamp) {
public void updateAttempt(final TaskAttemptStatus status,
final long timestamp) {
super.updateAttempt(status, timestamp);
TaskAttemptId attemptID = status.id;

View File

@ -24,108 +24,145 @@
* Implementation of the static model for Simple exponential smoothing.
*/
public class SimpleExponentialSmoothing {
public final static double DEFAULT_FORECAST = -1.0;
private static final double DEFAULT_FORECAST = -1.0;
private final int kMinimumReads;
private final long kStagnatedWindow;
private final long startTime;
private long timeConstant;
/**
* Holds reference to the current forecast record.
*/
private AtomicReference<ForecastRecord> forecastRefEntry;
public static SimpleExponentialSmoothing createForecast(long timeConstant,
int skipCnt, long stagnatedWindow, long timeStamp) {
public static SimpleExponentialSmoothing createForecast(
final long timeConstant,
final int skipCnt, final long stagnatedWindow, final long timeStamp) {
return new SimpleExponentialSmoothing(timeConstant, skipCnt,
stagnatedWindow, timeStamp);
}
SimpleExponentialSmoothing(long ktConstant, int skipCnt,
long stagnatedWindow, long timeStamp) {
kMinimumReads = skipCnt;
kStagnatedWindow = stagnatedWindow;
SimpleExponentialSmoothing(final long ktConstant, final int skipCnt,
final long stagnatedWindow, final long timeStamp) {
this.kMinimumReads = skipCnt;
this.kStagnatedWindow = stagnatedWindow;
this.timeConstant = ktConstant;
this.startTime = timeStamp;
this.forecastRefEntry = new AtomicReference<ForecastRecord>(null);
}
private class ForecastRecord {
private double alpha;
private long timeStamp;
private double sample;
private double rawData;
private final double alpha;
private final long timeStamp;
private final double sample;
private final double rawData;
private double forecast;
private double sseError;
private long myIndex;
private final double sseError;
private final long myIndex;
private ForecastRecord prevRec;
ForecastRecord(double forecast, double rawData, long timeStamp) {
this(0.0, forecast, rawData, forecast, timeStamp, 0.0, 0);
ForecastRecord(final double currForecast, final double currRawData,
final long currTimeStamp) {
this(0.0, currForecast, currRawData, currForecast, currTimeStamp, 0.0, 0);
}
ForecastRecord(double alpha, double sample, double rawData,
double forecast, long timeStamp, double accError, long index) {
this.timeStamp = timeStamp;
this.alpha = alpha;
this.sseError = 0.0;
this.sample = sample;
this.forecast = forecast;
this.rawData = rawData;
ForecastRecord(final double alphaVal, final double currSample,
final double currRawData,
final double currForecast, final long currTimeStamp,
final double accError,
final long index) {
this.timeStamp = currTimeStamp;
this.alpha = alphaVal;
this.sample = currSample;
this.forecast = currForecast;
this.rawData = currRawData;
this.sseError = accError;
this.myIndex = index;
}
private double preProcessRawData(double rData, long newTime) {
private ForecastRecord createForecastRecord(final double alphaVal,
final double currSample,
final double currRawData,
final double currForecast, final long currTimeStamp,
final double accError,
final long index,
final ForecastRecord prev) {
ForecastRecord forecastRec =
new ForecastRecord(alphaVal, currSample, currRawData, currForecast,
currTimeStamp, accError, index);
forecastRec.prevRec = prev;
return forecastRec;
}
private double preProcessRawData(final double rData, final long newTime) {
return processRawData(this.rawData, this.timeStamp, rData, newTime);
}
public ForecastRecord append(long newTimeStamp, double rData) {
if (this.timeStamp > newTimeStamp) {
public ForecastRecord append(final long newTimeStamp, final double rData) {
if (this.timeStamp >= newTimeStamp
&& Double.compare(this.rawData, rData) >= 0) {
// progress reported twice. Do nothing.
return this;
}
double newSample = preProcessRawData(rData, newTimeStamp);
ForecastRecord refRecord = this;
if (newTimeStamp == this.timeStamp) {
// we need to restore old value if possible
if (this.prevRec != null) {
refRecord = this.prevRec;
}
}
double newSample = refRecord.preProcessRawData(rData, newTimeStamp);
long deltaTime = this.timeStamp - newTimeStamp;
if (this.myIndex == kMinimumReads) {
if (refRecord.myIndex == kMinimumReads) {
timeConstant = Math.max(timeConstant, newTimeStamp - startTime);
}
double smoothFactor =
1 - Math.exp(((double) deltaTime) / timeConstant);
double forecastVal =
smoothFactor * newSample + (1.0 - smoothFactor) * this.forecast;
smoothFactor * newSample + (1.0 - smoothFactor) * refRecord.forecast;
double newSSEError =
this.sseError + Math.pow(newSample - this.forecast, 2);
return new ForecastRecord(smoothFactor, newSample, rData, forecastVal,
newTimeStamp, newSSEError, this.myIndex + 1);
refRecord.sseError + Math.pow(newSample - refRecord.forecast, 2);
return refRecord
.createForecastRecord(smoothFactor, newSample, rData, forecastVal,
newTimeStamp, newSSEError, refRecord.myIndex + 1, refRecord);
}
}
}
public boolean isDataStagnated(long timeStamp) {
/**
* checks if the task is hanging up.
* @param timeStamp current time of the scan.
* @return true if we have number of samples > kMinimumReads and the record
* timestamp has expired.
*/
public boolean isDataStagnated(final long timeStamp) {
ForecastRecord rec = forecastRefEntry.get();
if (rec != null && rec.myIndex <= kMinimumReads) {
return (rec.timeStamp + kStagnatedWindow) < timeStamp;
if (rec != null && rec.myIndex > kMinimumReads) {
return (rec.timeStamp + kStagnatedWindow) > timeStamp;
}
return false;
}
static double processRawData(double oldRawData, long oldTime,
double newRawData, long newTime) {
static double processRawData(final double oldRawData, final long oldTime,
final double newRawData, final long newTime) {
double rate = (newRawData - oldRawData) / (newTime - oldTime);
return rate;
}
public void incorporateReading(long timeStamp, double rawData) {
public void incorporateReading(final long timeStamp,
final double currRawData) {
ForecastRecord oldRec = forecastRefEntry.get();
if (oldRec == null) {
double oldForecast =
processRawData(0, startTime, rawData, timeStamp);
processRawData(0, startTime, currRawData, timeStamp);
forecastRefEntry.compareAndSet(null,
new ForecastRecord(oldForecast, 0.0, startTime));
incorporateReading(timeStamp, rawData);
incorporateReading(timeStamp, currRawData);
return;
}
while (!forecastRefEntry.compareAndSet(oldRec, oldRec.append(timeStamp,
rawData))) {
currRawData))) {
oldRec = forecastRefEntry.get();
}
}
public double getForecast() {
@ -136,7 +173,7 @@ public double getForecast() {
return DEFAULT_FORECAST;
}
public boolean isDefaultForecast(double value) {
public boolean isDefaultForecast(final double value) {
return value == DEFAULT_FORECAST;
}
@ -148,7 +185,7 @@ public double getSSE() {
return DEFAULT_FORECAST;
}
public boolean isErrorWithinBound(double bound) {
public boolean isErrorWithinBound(final double bound) {
double squaredErr = getSSE();
if (squaredErr < 0) {
return false;
@ -185,8 +222,8 @@ public String toString() {
String res = "NULL";
ForecastRecord rec = forecastRefEntry.get();
if (rec != null) {
res = "rec.index = " + rec.myIndex + ", forecast t: " + rec.timeStamp +
", forecast: " + rec.forecast
res = "rec.index = " + rec.myIndex + ", forecast t: " + rec.timeStamp
+ ", forecast: " + rec.forecast
+ ", sample: " + rec.sample + ", raw: " + rec.rawData + ", error: "
+ rec.sseError + ", alpha: " + rec.alpha;
}

View File

@ -0,0 +1,20 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
@InterfaceAudience.Private
package org.apache.hadoop.mapreduce.v2.app.speculate.forecast;
import org.apache.hadoop.classification.InterfaceAudience;

View File

@ -22,8 +22,11 @@
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileContext;
import org.apache.hadoop.fs.FileSystem;
@ -372,18 +375,45 @@ public void waitForState(TaskAttempt attempt,
TaskAttemptReport report = attempt.getReport();
while (!finalState.equals(report.getTaskAttemptState()) &&
timeoutSecs++ < 20) {
System.out.println("TaskAttempt State is : " + report.getTaskAttemptState() +
" Waiting for state : " + finalState +
" progress : " + report.getProgress());
System.out.println(
"TaskAttempt " + attempt.getID().toString() + " State is : "
+ report.getTaskAttemptState()
+ " Waiting for state : " + finalState
+ " progress : " + report.getProgress());
report = attempt.getReport();
Thread.sleep(500);
}
System.out.println("TaskAttempt State is : " + report.getTaskAttemptState());
System.out.println("TaskAttempt State is : "
+ report.getTaskAttemptState());
Assert.assertEquals("TaskAttempt state is not correct (timedout)",
finalState,
report.getTaskAttemptState());
}
public void waitForState(TaskAttempt attempt,
TaskAttemptState...finalStates) throws Exception {
int timeoutSecs = 0;
TaskAttemptReport report = attempt.getReport();
List<TaskAttemptState> targetStates = Arrays.asList(finalStates);
String statesValues = targetStates.stream().map(Object::toString).collect(
Collectors.joining(","));
while (!targetStates.contains(report.getTaskAttemptState()) &&
timeoutSecs++ < 20) {
System.out.println(
"TaskAttempt " + attempt.getID().toString() + " State is : "
+ report.getTaskAttemptState()
+ " Waiting for states: " + statesValues
+ ". curent state is : " + report.getTaskAttemptState()
+ ". progress : " + report.getProgress());
report = attempt.getReport();
Thread.sleep(500);
}
System.out.println("TaskAttempt State is : "
+ report.getTaskAttemptState());
Assert.assertTrue("TaskAttempt state is not correct (timedout)",
targetStates.contains(report.getTaskAttemptState()));
}
public void waitForState(Task task, TaskState finalState) throws Exception {
int timeoutSecs = 0;
TaskReport report = task.getReport();

View File

@ -18,11 +18,14 @@
package org.apache.hadoop.mapreduce.v2;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.hadoop.mapreduce.MRJobConfig;
@ -50,19 +53,94 @@
import org.apache.hadoop.yarn.util.Clock;
import org.apache.hadoop.yarn.util.ControlledClock;
import org.apache.hadoop.yarn.util.SystemClock;
import org.junit.Rule;
import org.junit.Test;
import com.google.common.base.Supplier;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.model.Statement;
/**
* The type Test speculative execution with mr app.
* It test the speculation behavior given a list of estimator classes.
*/
@SuppressWarnings({ "unchecked", "rawtypes" })
@RunWith(Parameterized.class)
public class TestSpeculativeExecutionWithMRApp {
/** Number of times to re-try the failing tests. */
private static final int ASSERT_SPECULATIONS_COUNT_RETRIES = 3;
private static final int NUM_MAPPERS = 5;
private static final int NUM_REDUCERS = 0;
/**
* Speculation has non-deterministic behavior due to racing and timing. Use
* retry to verify that junit tests can pass.
*/
@Retention(RetentionPolicy.RUNTIME)
public @interface Retry {}
/**
* The type Retry rule.
*/
class RetryRule implements TestRule {
private AtomicInteger retryCount;
/**
* Instantiates a new Retry rule.
*
* @param retries the retries
*/
RetryRule(int retries) {
super();
this.retryCount = new AtomicInteger(retries);
}
@Override
public Statement apply(final Statement base,
final Description description) {
return new Statement() {
@Override
public void evaluate() throws Throwable {
Throwable caughtThrowable = null;
while (retryCount.getAndDecrement() > 0) {
try {
base.evaluate();
return;
} catch (Throwable t) {
if (retryCount.get() > 0 &&
description.getAnnotation(Retry.class) != null) {
caughtThrowable = t;
System.out.println(
description.getDisplayName() +
": Failed, " +
retryCount.toString() +
" retries remain");
} else {
throw caughtThrowable;
}
}
}
}
};
}
}
/**
* The Rule.
*/
@Rule
public RetryRule rule = new RetryRule(ASSERT_SPECULATIONS_COUNT_RETRIES);
/**
* Gets test parameters.
*
* @return the test parameters
*/
@Parameterized.Parameters(name = "{index}: TaskEstimator(EstimatorClass {0})")
public static Collection<Object[]> getTestParameters() {
return Arrays.asList(new Object[][] {
@ -73,12 +151,23 @@ public static Collection<Object[]> getTestParameters() {
private Class<? extends TaskRuntimeEstimator> estimatorClass;
/**
* Instantiates a new Test speculative execution with mr app.
*
* @param estimatorKlass the estimator klass
*/
public TestSpeculativeExecutionWithMRApp(
Class<? extends TaskRuntimeEstimator> estimatorKlass) {
this.estimatorClass = estimatorKlass;
}
@Test
/**
* Test speculate successful without update events.
*
* @throws Exception the exception
*/
@Retry
@Test (timeout = 360000)
public void testSpeculateSuccessfulWithoutUpdateEvents() throws Exception {
Clock actualClock = SystemClock.getInstance();
@ -128,7 +217,8 @@ public void testSpeculateSuccessfulWithoutUpdateEvents() throws Exception {
TaskAttemptEventType.TA_DONE));
appEventHandler.handle(new TaskAttemptEvent(taskAttempt.getKey(),
TaskAttemptEventType.TA_CONTAINER_COMPLETED));
app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED);
app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED,
TaskAttemptState.KILLED);
}
}
}
@ -150,8 +240,14 @@ public Boolean get() {
app.waitForState(Service.STATE.STOPPED);
}
@Test
public void testSepculateSuccessfulWithUpdateEvents() throws Exception {
/**
* Test speculate successful with update events.
*
* @throws Exception the exception
*/
@Retry
@Test (timeout = 360000)
public void testSpeculateSuccessfulWithUpdateEvents() throws Exception {
Clock actualClock = SystemClock.getInstance();
final ControlledClock clock = new ControlledClock(actualClock);
@ -198,7 +294,8 @@ public void testSepculateSuccessfulWithUpdateEvents() throws Exception {
appEventHandler.handle(new TaskAttemptEvent(taskAttempt.getKey(),
TaskAttemptEventType.TA_CONTAINER_COMPLETED));
numTasksToFinish--;
app.waitForState(taskAttempt.getValue(), TaskAttemptState.SUCCEEDED);
app.waitForState(taskAttempt.getValue(), TaskAttemptState.KILLED,
TaskAttemptState.SUCCEEDED);
} else {
// The last task is chosen for speculation
TaskAttemptStatus status =
@ -214,13 +311,12 @@ public void testSepculateSuccessfulWithUpdateEvents() throws Exception {
}
clock.setTime(System.currentTimeMillis() + 15000);
// give a chance to the speculator thread to run a scan before we proceed
// with updating events
Thread.yield();
for (Map.Entry<TaskId, Task> task : tasks.entrySet()) {
for (Map.Entry<TaskAttemptId, TaskAttempt> taskAttempt : task.getValue()
.getAttempts().entrySet()) {
if (taskAttempt.getValue().getState() != TaskAttemptState.SUCCEEDED) {
if (!(taskAttempt.getValue().getState() == TaskAttemptState.SUCCEEDED
|| taskAttempt.getValue().getState() == TaskAttemptState.KILLED)) {
TaskAttemptStatus status =
createTaskAttemptStatus(taskAttempt.getKey(), (float) 0.75,
TaskAttemptState.RUNNING);