[ML] rewriting stats gathering to use callbacks instead of a latch (#41793) (#41804)

This commit is contained in:
Benjamin Trent 2019-05-03 18:18:27 -05:00 committed by GitHub
parent c7924014fa
commit b69e28177b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 100 additions and 91 deletions

View File

@ -20,6 +20,7 @@ import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.inject.Inject;
@ -99,6 +100,8 @@ public class TransportGetDataFrameTransformsStatsAction extends
@Override @Override
protected void taskOperation(Request request, DataFrameTransformTask task, ActionListener<Response> listener) { protected void taskOperation(Request request, DataFrameTransformTask task, ActionListener<Response> listener) {
// Little extra insurance, make sure we only return transforms that aren't cancelled // Little extra insurance, make sure we only return transforms that aren't cancelled
ClusterState state = clusterService.state();
String nodeId = state.nodes().getLocalNode().getId();
if (task.isCancelled() == false) { if (task.isCancelled() == false) {
transformsCheckpointService.getCheckpointStats(task.getTransformId(), task.getCheckpoint(), task.getInProgressCheckpoint(), transformsCheckpointService.getCheckpointStats(task.getTransformId(), task.getCheckpoint(), task.getInProgressCheckpoint(),
ActionListener.wrap(checkpointStats -> { ActionListener.wrap(checkpointStats -> {
@ -109,7 +112,7 @@ public class TransportGetDataFrameTransformsStatsAction extends
Collections.singletonList(new DataFrameTransformStateAndStats(task.getTransformId(), task.getState(), Collections.singletonList(new DataFrameTransformStateAndStats(task.getTransformId(), task.getState(),
task.getStats(), DataFrameTransformCheckpointingInfo.EMPTY)), task.getStats(), DataFrameTransformCheckpointingInfo.EMPTY)),
Collections.emptyList(), Collections.emptyList(),
Collections.singletonList(new FailedNodeException("", "Failed to retrieve checkpointing info", e)))); Collections.singletonList(new FailedNodeException(nodeId, "Failed to retrieve checkpointing info", e))));
})); }));
} else { } else {
listener.onResponse(new Response(Collections.emptyList())); listener.onResponse(new Response(Collections.emptyList()));

View File

@ -9,7 +9,6 @@ package org.elasticsearch.xpack.dataframe.checkpoint;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.action.admin.indices.get.GetIndexAction; import org.elasticsearch.action.admin.indices.get.GetIndexAction;
import org.elasticsearch.action.admin.indices.get.GetIndexRequest; import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
import org.elasticsearch.action.admin.indices.stats.IndicesStatsAction; import org.elasticsearch.action.admin.indices.stats.IndicesStatsAction;
@ -28,8 +27,6 @@ import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
/** /**
* DataFrameTransform Checkpoint Service * DataFrameTransform Checkpoint Service
@ -41,17 +38,22 @@ import java.util.concurrent.TimeUnit;
*/ */
public class DataFrameTransformsCheckpointService { public class DataFrameTransformsCheckpointService {
private class Checkpoints { private static class Checkpoints {
DataFrameTransformCheckpoint currentCheckpoint = DataFrameTransformCheckpoint.EMPTY; DataFrameTransformCheckpoint currentCheckpoint = DataFrameTransformCheckpoint.EMPTY;
DataFrameTransformCheckpoint inProgressCheckpoint = DataFrameTransformCheckpoint.EMPTY; DataFrameTransformCheckpoint inProgressCheckpoint = DataFrameTransformCheckpoint.EMPTY;
DataFrameTransformCheckpoint sourceCheckpoint = DataFrameTransformCheckpoint.EMPTY; DataFrameTransformCheckpoint sourceCheckpoint = DataFrameTransformCheckpoint.EMPTY;
DataFrameTransformCheckpointingInfo buildInfo() {
return new DataFrameTransformCheckpointingInfo(
new DataFrameTransformCheckpointStats(currentCheckpoint.getTimestamp(), currentCheckpoint.getTimeUpperBound()),
new DataFrameTransformCheckpointStats(inProgressCheckpoint.getTimestamp(), inProgressCheckpoint.getTimeUpperBound()),
DataFrameTransformCheckpoint.getBehind(currentCheckpoint, sourceCheckpoint));
}
} }
private static final Logger logger = LogManager.getLogger(DataFrameTransformsCheckpointService.class); private static final Logger logger = LogManager.getLogger(DataFrameTransformsCheckpointService.class);
// timeout for retrieving checkpoint information
private static final int CHECKPOINT_STATS_TIMEOUT_SECONDS = 5;
private final Client client; private final Client client;
private final DataFrameTransformsConfigManager dataFrameTransformsConfigManager; private final DataFrameTransformsConfigManager dataFrameTransformsConfigManager;
@ -86,40 +88,49 @@ public class DataFrameTransformsCheckpointService {
long timeUpperBound = 0; long timeUpperBound = 0;
// 1st get index to see the indexes the user has access to // 1st get index to see the indexes the user has access to
GetIndexRequest getIndexRequest = new GetIndexRequest().indices(transformConfig.getSource().getIndex()); GetIndexRequest getIndexRequest = new GetIndexRequest()
.indices(transformConfig.getSource().getIndex())
.features(new GetIndexRequest.Feature[0]);
ClientHelper.executeWithHeadersAsync(transformConfig.getHeaders(), ClientHelper.DATA_FRAME_ORIGIN, client, GetIndexAction.INSTANCE, ClientHelper.executeWithHeadersAsync(transformConfig.getHeaders(), ClientHelper.DATA_FRAME_ORIGIN, client, GetIndexAction.INSTANCE,
getIndexRequest, ActionListener.wrap(getIndexResponse -> { getIndexRequest, ActionListener.wrap(getIndexResponse -> {
Set<String> userIndices = new HashSet<>(Arrays.asList(getIndexResponse.getIndices())); Set<String> userIndices = new HashSet<>(Arrays.asList(getIndexResponse.getIndices()));
// 2nd get stats request // 2nd get stats request
ClientHelper.executeAsyncWithOrigin(client, ClientHelper.DATA_FRAME_ORIGIN, IndicesStatsAction.INSTANCE, ClientHelper.executeAsyncWithOrigin(client,
new IndicesStatsRequest().indices(transformConfig.getSource().getIndex()), ActionListener.wrap(response -> { ClientHelper.DATA_FRAME_ORIGIN,
IndicesStatsAction.INSTANCE,
new IndicesStatsRequest()
.indices(transformConfig.getSource().getIndex())
.clear(),
ActionListener.wrap(
response -> {
if (response.getFailedShards() != 0) { if (response.getFailedShards() != 0) {
throw new CheckpointException("Source has [" + response.getFailedShards() + "] failed shards"); listener.onFailure(
new CheckpointException("Source has [" + response.getFailedShards() + "] failed shards"));
return;
} }
try {
Map<String, long[]> checkpointsByIndex = extractIndexCheckPoints(response.getShards(), userIndices); Map<String, long[]> checkpointsByIndex = extractIndexCheckPoints(response.getShards(), userIndices);
DataFrameTransformCheckpoint checkpointDoc = new DataFrameTransformCheckpoint(transformConfig.getId(), listener.onResponse(new DataFrameTransformCheckpoint(transformConfig.getId(),
timestamp, checkpoint, checkpointsByIndex, timeUpperBound); timestamp,
checkpoint,
listener.onResponse(checkpointDoc); checkpointsByIndex,
timeUpperBound));
}, IndicesStatsRequestException -> { } catch (CheckpointException checkpointException) {
throw new CheckpointException("Failed to retrieve indices stats", IndicesStatsRequestException); listener.onFailure(checkpointException);
})); }
},
}, getIndexException -> { listener::onFailure
throw new CheckpointException("Failed to retrieve list of indices", getIndexException); ));
})); },
listener::onFailure
));
} }
/** /**
* Get checkpointing stats for a data frame * Get checkpointing stats for a data frame
* *
* Implementation details:
* - fires up to 3 requests _in parallel_ rather than cascading them
* *
* @param transformId The data frame task * @param transformId The data frame task
* @param currentCheckpoint the current checkpoint * @param currentCheckpoint the current checkpoint
@ -132,71 +143,66 @@ public class DataFrameTransformsCheckpointService {
long inProgressCheckpoint, long inProgressCheckpoint,
ActionListener<DataFrameTransformCheckpointingInfo> listener) { ActionListener<DataFrameTransformCheckpointingInfo> listener) {
// process in parallel: current checkpoint, in-progress checkpoint, current state of the source
CountDownLatch latch = new CountDownLatch(3);
// ensure listener is called exactly once
final ActionListener<DataFrameTransformCheckpointingInfo> wrappedListener = ActionListener.notifyOnce(listener);
// holder structure for writing the results of the 3 parallel tasks
Checkpoints checkpoints = new Checkpoints(); Checkpoints checkpoints = new Checkpoints();
// get the current checkpoint // <3> notify the user once we have the current checkpoint
ActionListener<DataFrameTransformCheckpoint> currentCheckpointListener = ActionListener.wrap(
currentCheckpointObj -> {
checkpoints.currentCheckpoint = currentCheckpointObj;
listener.onResponse(checkpoints.buildInfo());
},
e -> {
logger.debug("Failed to retrieve current checkpoint [" +
currentCheckpoint + "] for data frame [" + transformId + "]", e);
listener.onFailure(new CheckpointException("Failure during current checkpoint info retrieval", e));
}
);
// <2> after the in progress checkpoint, get the current checkpoint
ActionListener<DataFrameTransformCheckpoint> inProgressCheckpointListener = ActionListener.wrap(
inProgressCheckpointObj -> {
checkpoints.inProgressCheckpoint = inProgressCheckpointObj;
if (currentCheckpoint != 0) { if (currentCheckpoint != 0) {
dataFrameTransformsConfigManager.getTransformCheckpoint(transformId, currentCheckpoint, dataFrameTransformsConfigManager.getTransformCheckpoint(transformId,
new LatchedActionListener<>(ActionListener.wrap(checkpoint -> checkpoints.currentCheckpoint = checkpoint, e -> { currentCheckpoint,
logger.debug("Failed to retrieve checkpoint [" + currentCheckpoint + "] for data frame []" + transformId, e); currentCheckpointListener);
wrappedListener
.onFailure(new CheckpointException("Failed to retrieve current checkpoint [" + currentCheckpoint + "]", e));
}), latch));
} else { } else {
latch.countDown(); currentCheckpointListener.onResponse(DataFrameTransformCheckpoint.EMPTY);
} }
},
e -> {
logger.debug("Failed to retrieve in progress checkpoint [" +
inProgressCheckpoint + "] for data frame [" + transformId + "]", e);
listener.onFailure(new CheckpointException("Failure during in progress checkpoint info retrieval", e));
}
);
// get the in-progress checkpoint // <1> after the source checkpoint, get the in progress checkpoint
ActionListener<DataFrameTransformCheckpoint> sourceCheckpointListener = ActionListener.wrap(
sourceCheckpoint -> {
checkpoints.sourceCheckpoint = sourceCheckpoint;
if (inProgressCheckpoint != 0) { if (inProgressCheckpoint != 0) {
dataFrameTransformsConfigManager.getTransformCheckpoint(transformId, inProgressCheckpoint, dataFrameTransformsConfigManager.getTransformCheckpoint(transformId,
new LatchedActionListener<>(ActionListener.wrap(checkpoint -> checkpoints.inProgressCheckpoint = checkpoint, e -> { inProgressCheckpoint,
logger.debug("Failed to retrieve in progress checkpoint [" + inProgressCheckpoint + "] for data frame [" inProgressCheckpointListener);
+ transformId + "]", e);
wrappedListener.onFailure(
new CheckpointException("Failed to retrieve in progress checkpoint [" + inProgressCheckpoint + "]", e));
}), latch));
} else { } else {
latch.countDown(); inProgressCheckpointListener.onResponse(DataFrameTransformCheckpoint.EMPTY);
} }
},
e -> {
logger.debug("Failed to retrieve source checkpoint for data frame [" + transformId + "]", e);
listener.onFailure(new CheckpointException("Failure during source checkpoint info retrieval", e));
}
);
// get the current state // <0> get the transform and the source, transient checkpoint
dataFrameTransformsConfigManager.getTransformConfiguration(transformId, ActionListener.wrap(transformConfig -> { dataFrameTransformsConfigManager.getTransformConfiguration(transformId, ActionListener.wrap(
getCheckpoint(transformConfig, transformConfig -> getCheckpoint(transformConfig, sourceCheckpointListener),
new LatchedActionListener<>(ActionListener.wrap(checkpoint -> checkpoints.sourceCheckpoint = checkpoint, e2 -> { transformError -> {
logger.debug("Failed to retrieve actual checkpoint for data frame [" + transformId + "]", e2); logger.warn("Failed to retrieve configuration for data frame [" + transformId + "]", transformError);
wrappedListener.onFailure(new CheckpointException("Failed to retrieve actual checkpoint", e2)); listener.onFailure(new CheckpointException("Failed to retrieve configuration", transformError));
}), latch)); })
}, e -> { );
logger.warn("Failed to retrieve configuration for data frame [" + transformId + "]", e);
wrappedListener.onFailure(new CheckpointException("Failed to retrieve configuration", e));
latch.countDown();
}));
try {
if (latch.await(CHECKPOINT_STATS_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
logger.debug("Retrieval of checkpoint information succeeded for data frame [" + transformId + "]");
wrappedListener.onResponse(new DataFrameTransformCheckpointingInfo(
new DataFrameTransformCheckpointStats(checkpoints.currentCheckpoint.getTimestamp(),
checkpoints.currentCheckpoint.getTimeUpperBound()),
new DataFrameTransformCheckpointStats(checkpoints.inProgressCheckpoint.getTimestamp(),
checkpoints.inProgressCheckpoint.getTimeUpperBound()),
DataFrameTransformCheckpoint.getBehind(checkpoints.currentCheckpoint, checkpoints.sourceCheckpoint)));
} else {
// timed out
logger.debug("Retrieval of checkpoint information has timed out for data frame [" + transformId + "]");
wrappedListener.onFailure(new CheckpointException("Retrieval of checkpoint information has timed out"));
}
} catch (InterruptedException e) {
logger.debug("Failed to retrieve checkpoints for data frame [" + transformId + "]", e);
wrappedListener.onFailure(new CheckpointException("Failure during checkpoint info retrieval", e));
}
} }
static Map<String, long[]> extractIndexCheckPoints(ShardStats[] shards, Set<String> userIndices) { static Map<String, long[]> extractIndexCheckPoints(ShardStats[] shards, Set<String> userIndices) {