[7.x][ML] Allow analytics process define its own progress phases (#55763) (#55791)

This is a continuation from #55580.

Now that we're parsing phase progresses from the analytics process
we change `ProgressTracker` to allow for custom phases between
the `loading_data` and `writing_results` phases. Each `DataFrameAnalysis`
may declare its own phases.

This commit sets things in place for the analytics process to start
reporting different phases per analysis type. However, this is
still preserving existing behaviour as all analyses currently
declare a single `analyzing` phase.

Backport of #55763
This commit is contained in:
Dimitris Athanasiou 2020-04-27 13:30:05 +03:00 committed by GitHub
parent fe9904fbea
commit 7f100c1196
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 318 additions and 86 deletions

View File

@ -348,6 +348,11 @@ public class Classification implements DataFrameAnalysis {
return jobId + STATE_DOC_ID_SUFFIX;
}
@Override
public List<String> getProgressPhases() {
return Collections.singletonList("analyzing");
}
public static String extractJobIdFromStateDoc(String stateDocId) {
int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX);
return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex);

View File

@ -66,6 +66,11 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
*/
String getStateDocId(String jobId);
/**
* Returns the progress phases the analysis goes through in order
*/
List<String> getProgressPhases();
/**
* Summarizes information about the fields that is necessary for analysis to generate
* the parameters needed for the process configuration.

View File

@ -249,6 +249,11 @@ public class OutlierDetection implements DataFrameAnalysis {
throw new UnsupportedOperationException("Outlier detection does not support state");
}
@Override
public List<String> getProgressPhases() {
return Collections.singletonList("analyzing");
}
public enum Method {
LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN;

View File

@ -213,6 +213,11 @@ public class Regression implements DataFrameAnalysis {
return jobId + STATE_DOC_ID_SUFFIX;
}
@Override
public List<String> getProgressPhases() {
return Collections.singletonList("analyzing");
}
public static String extractJobIdFromStateDoc(String stateDocId) {
int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX);
return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex);

View File

@ -44,9 +44,9 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
@ -55,7 +55,6 @@ import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.StoredProgress;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.utils.persistence.MlParserUtils;
import java.util.ArrayList;
@ -105,25 +104,20 @@ public class TransportGetDataFrameAnalyticsStatsAction
ActionListener<QueryPage<Stats>> listener) {
logger.debug("Get stats for running task [{}]", task.getParams().getId());
ActionListener<StatsHolder> statsHolderListener = ActionListener.wrap(
statsHolder -> {
ActionListener<Void> reindexingProgressListener = ActionListener.wrap(
aVoid -> {
Stats stats = buildStats(
task.getParams().getId(),
statsHolder.getProgressTracker().report(),
statsHolder.getDataCountsTracker().report(task.getParams().getId()),
statsHolder.getMemoryUsage(),
statsHolder.getAnalysisStats()
task.getStatsHolder().getProgressTracker().report(),
task.getStatsHolder().getDataCountsTracker().report(task.getParams().getId()),
task.getStatsHolder().getMemoryUsage(),
task.getStatsHolder().getAnalysisStats()
);
listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1,
GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
}, listener::onFailure
);
ActionListener<Void> reindexingProgressListener = ActionListener.wrap(
aVoid -> statsHolderListener.onResponse(task.getStatsHolder()),
listener::onFailure
);
task.updateReindexTaskProgress(reindexingProgressListener);
}
@ -138,7 +132,7 @@ public class TransportGetDataFrameAnalyticsStatsAction
.collect(Collectors.toList());
request.setExpandedIds(expandedIds);
ActionListener<GetDataFrameAnalyticsStatsAction.Response> runningTasksStatsListener = ActionListener.wrap(
runningTasksStatsResponse -> gatherStatsForStoppedTasks(request.getExpandedIds(), runningTasksStatsResponse,
runningTasksStatsResponse -> gatherStatsForStoppedTasks(getResponse.getResources().results(), runningTasksStatsResponse,
ActionListener.wrap(
finalResponse -> {
@ -163,20 +157,20 @@ public class TransportGetDataFrameAnalyticsStatsAction
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsAction.INSTANCE, getRequest, getResponseListener);
}
void gatherStatsForStoppedTasks(List<String> expandedIds, GetDataFrameAnalyticsStatsAction.Response runningTasksResponse,
void gatherStatsForStoppedTasks(List<DataFrameAnalyticsConfig> configs, GetDataFrameAnalyticsStatsAction.Response runningTasksResponse,
ActionListener<GetDataFrameAnalyticsStatsAction.Response> listener) {
List<String> stoppedTasksIds = determineStoppedTasksIds(expandedIds, runningTasksResponse.getResponse().results());
if (stoppedTasksIds.isEmpty()) {
List<DataFrameAnalyticsConfig> stoppedConfigs = determineStoppedConfigs(configs, runningTasksResponse.getResponse().results());
if (stoppedConfigs.isEmpty()) {
listener.onResponse(runningTasksResponse);
return;
}
AtomicInteger counter = new AtomicInteger(stoppedTasksIds.size());
AtomicArray<Stats> jobStats = new AtomicArray<>(stoppedTasksIds.size());
for (int i = 0; i < stoppedTasksIds.size(); i++) {
AtomicInteger counter = new AtomicInteger(stoppedConfigs.size());
AtomicArray<Stats> jobStats = new AtomicArray<>(stoppedConfigs.size());
for (int i = 0; i < stoppedConfigs.size(); i++) {
final int slot = i;
String jobId = stoppedTasksIds.get(i);
searchStats(jobId, ActionListener.wrap(
DataFrameAnalyticsConfig config = stoppedConfigs.get(i);
searchStats(config, ActionListener.wrap(
stats -> {
jobStats.set(slot, stats);
if (counter.decrementAndGet() == 0) {
@ -192,21 +186,24 @@ public class TransportGetDataFrameAnalyticsStatsAction
}
}
static List<String> determineStoppedTasksIds(List<String> expandedIds, List<Stats> runningTasksStats) {
static List<DataFrameAnalyticsConfig> determineStoppedConfigs(List<DataFrameAnalyticsConfig> configs, List<Stats> runningTasksStats) {
Set<String> startedTasksIds = runningTasksStats.stream().map(Stats::getId).collect(Collectors.toSet());
return expandedIds.stream().filter(id -> startedTasksIds.contains(id) == false).collect(Collectors.toList());
return configs.stream().filter(config -> startedTasksIds.contains(config.getId()) == false).collect(Collectors.toList());
}
private void searchStats(String configId, ActionListener<Stats> listener) {
RetrievedStatsHolder retrievedStatsHolder = new RetrievedStatsHolder();
private void searchStats(DataFrameAnalyticsConfig config, ActionListener<Stats> listener) {
logger.debug("[{}] Gathering stats for stopped task", config.getId());
RetrievedStatsHolder retrievedStatsHolder = new RetrievedStatsHolder(
ProgressTracker.fromZeroes(config.getAnalysis().getProgressPhases()).report());
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
multiSearchRequest.add(buildStoredProgressSearch(configId));
multiSearchRequest.add(buildStatsDocSearch(configId, DataCounts.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(configId, MemoryUsage.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(configId, OutlierDetectionStats.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(configId, ClassificationStats.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(configId, RegressionStats.TYPE_VALUE));
multiSearchRequest.add(buildStoredProgressSearch(config.getId()));
multiSearchRequest.add(buildStatsDocSearch(config.getId(), DataCounts.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(config.getId(), MemoryUsage.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(config.getId(), OutlierDetectionStats.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(config.getId(), ClassificationStats.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(config.getId(), RegressionStats.TYPE_VALUE));
executeAsyncWithOrigin(client, ML_ORIGIN, MultiSearchAction.INSTANCE, multiSearchRequest, ActionListener.wrap(
multiSearchResponse -> {
@ -218,7 +215,7 @@ public class TransportGetDataFrameAnalyticsStatsAction
logger.error(
new ParameterizedMessage(
"[{}] Item failure encountered during multi search for request [indices={}, source={}]: {}",
configId, itemRequest.indices(), itemRequest.source(), itemResponse.getFailureMessage()),
config.getId(), itemRequest.indices(), itemRequest.source(), itemResponse.getFailureMessage()),
itemResponse.getFailure());
listener.onFailure(ExceptionsHelper.serverError(itemResponse.getFailureMessage(), itemResponse.getFailure()));
return;
@ -227,13 +224,13 @@ public class TransportGetDataFrameAnalyticsStatsAction
if (hits.length == 0) {
// Not found
} else if (hits.length == 1) {
parseHit(hits[0], configId, retrievedStatsHolder);
parseHit(hits[0], config.getId(), retrievedStatsHolder);
} else {
throw ExceptionsHelper.serverError("Found [" + hits.length + "] hits when just one was requested");
}
}
}
listener.onResponse(buildStats(configId,
listener.onResponse(buildStats(config.getId(),
retrievedStatsHolder.progress.get(),
retrievedStatsHolder.dataCounts,
retrievedStatsHolder.memoryUsage,
@ -320,9 +317,13 @@ public class TransportGetDataFrameAnalyticsStatsAction
private static class RetrievedStatsHolder {
private volatile StoredProgress progress = new StoredProgress(new ProgressTracker().report());
private volatile StoredProgress progress;
private volatile DataCounts dataCounts;
private volatile MemoryUsage memoryUsage;
private volatile AnalysisStats analysisStats;
private RetrievedStatsHolder(List<PhaseProgress> defaultProgress) {
progress = new StoredProgress(defaultProgress);
}
}
}

View File

@ -81,6 +81,11 @@ public class DataFrameAnalyticsManager {
// With config in hand, determine action to take
ActionListener<DataFrameAnalyticsConfig> configListener = ActionListener.wrap(
config -> {
// At this point we have the config at hand and we can reset the progress tracker
// to use the analyses phases. We preserve reindexing progress as if reindexing was
// finished it will not be reset.
task.getStatsHolder().resetProgressTrackerPreservingReindexingProgress(config.getAnalysis().getProgressPhases());
switch(currentState) {
// If we are STARTED, it means the job was started because the start API was called.
// We should determine the job's starting state based on its previous progress.
@ -217,7 +222,6 @@ public class DataFrameAnalyticsManager {
return;
}
task.setReindexingTaskId(null);
task.setReindexingFinished();
auditor.info(
config.getId(),
Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_FINISHED_REINDEXING, config.getDest().getIndex(),
@ -296,6 +300,7 @@ public class DataFrameAnalyticsManager {
task.markAsCompleted();
return;
}
final ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
// Update state to ANALYZING and start process
ActionListener<DataFrameDataExtractorFactory> dataExtractorFactoryListener = ActionListener.wrap(
@ -327,8 +332,8 @@ public class DataFrameAnalyticsManager {
ActionListener<RefreshResponse> refreshListener = ActionListener.wrap(
refreshResponse -> {
// Ensure we mark reindexing is finished for the case we are recovering a task that had finished reindexing
task.setReindexingFinished();
// Now we can ensure reindexing progress is complete
task.getStatsHolder().getProgressTracker().updateReindexingProgress(100);
// TODO This could fail with errors. In that case we get stuck with the copied index.
// We could delete the index in case of failure or we could try building the factory before reindexing

View File

@ -67,10 +67,9 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
private final StartDataFrameAnalyticsAction.TaskParams taskParams;
@Nullable
private volatile Long reindexingTaskId;
private volatile boolean isReindexingFinished;
private volatile boolean isStopping;
private volatile boolean isMarkAsCompletedCalled;
private final StatsHolder statsHolder = new StatsHolder();
private final StatsHolder statsHolder;
public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map<String, String> headers,
Client client, ClusterService clusterService, DataFrameAnalyticsManager analyticsManager,
@ -81,6 +80,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
this.analyticsManager = Objects.requireNonNull(analyticsManager);
this.auditor = Objects.requireNonNull(auditor);
this.taskParams = Objects.requireNonNull(taskParams);
this.statsHolder = new StatsHolder(taskParams.getProgressOnStart());
}
public StartDataFrameAnalyticsAction.TaskParams getParams() {
@ -92,10 +92,6 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
this.reindexingTaskId = reindexingTaskId;
}
public void setReindexingFinished() {
isReindexingFinished = true;
}
public boolean isStopping() {
return isStopping;
}
@ -222,7 +218,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
// We set reindexing progress at least to 1 for a running process to be able to
// distinguish a job that is running for the first time against a job that is restarting.
reindexTaskProgress -> {
statsHolder.getProgressTracker().reindexingPercent.set(Math.max(1, reindexTaskProgress));
statsHolder.getProgressTracker().updateReindexingProgress(Math.max(1, reindexTaskProgress));
listener.onResponse(null);
},
listener::onFailure
@ -232,9 +228,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
private void getReindexTaskProgress(ActionListener<Integer> listener) {
TaskId reindexTaskId = getReindexTaskId();
if (reindexTaskId == null) {
// The task is not present which means either it has not started yet or it finished.
// We keep track of whether the task has finished so we can use that to tell whether the progress 100.
listener.onResponse(isReindexingFinished ? 100 : 0);
listener.onResponse(statsHolder.getProgressTracker().getReindexingProgressPercent());
return;
}
@ -250,8 +244,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
error -> {
if (ExceptionsHelper.unwrapCause(error) instanceof ResourceNotFoundException) {
// The task is not present which means either it has not started yet or it finished.
// We keep track of whether the task has finished so we can use that to tell whether the progress 100.
listener.onResponse(isReindexingFinished ? 100 : 0);
listener.onResponse(statsHolder.getProgressTracker().getReindexingProgressPercent());
} else {
listener.onFailure(error);
}
@ -365,17 +358,10 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
LOGGER.debug("[{}] Last incomplete progress [{}, {}]", jobId, lastIncompletePhase.getPhase(),
lastIncompletePhase.getProgressPercent());
switch (lastIncompletePhase.getPhase()) {
case ProgressTracker.REINDEXING:
return lastIncompletePhase.getProgressPercent() == 0 ? StartingState.FIRST_TIME : StartingState.RESUMING_REINDEXING;
case ProgressTracker.LOADING_DATA:
case ProgressTracker.ANALYZING:
case ProgressTracker.WRITING_RESULTS:
return StartingState.RESUMING_ANALYZING;
default:
LOGGER.warn("[{}] Unexpected progress phase [{}]", jobId, lastIncompletePhase.getPhase());
return StartingState.FIRST_TIME;
if (ProgressTracker.REINDEXING.equals(lastIncompletePhase.getPhase())) {
return lastIncompletePhase.getProgressPercent() == 0 ? StartingState.FIRST_TIME : StartingState.RESUMING_REINDEXING;
}
return StartingState.RESUMING_ANALYZING;
}
}

View File

@ -246,7 +246,7 @@ public class AnalyticsProcessManager {
}
}
rowsProcessed += rows.get().size();
progressTracker.loadingDataPercent.set(rowsProcessed >= totalRows ? 100 : (int) (rowsProcessed * 100.0 / totalRows));
progressTracker.updateLoadingDataProgress(rowsProcessed >= totalRows ? 100 : (int) (rowsProcessed * 100.0 / totalRows));
}
}
}

View File

@ -153,11 +153,11 @@ public class AnalyticsResultProcessor {
}
private void updateResultsProgress(int progress) {
statsHolder.getProgressTracker().writingResultsPercent.set(Math.min(progress, MAX_PROGRESS_BEFORE_COMPLETION));
statsHolder.getProgressTracker().updateWritingResultsProgress(Math.min(progress, MAX_PROGRESS_BEFORE_COMPLETION));
}
private void completeResultsProgress() {
statsHolder.getProgressTracker().writingResultsPercent.set(100);
statsHolder.getProgressTracker().updateWritingResultsProgress(100);
}
private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJoiner) {
@ -169,7 +169,7 @@ public class AnalyticsResultProcessor {
if (phaseProgress != null) {
LOGGER.debug("[{}] progress for phase [{}] updated to [{}]", analytics.getId(), phaseProgress.getPhase(),
phaseProgress.getProgressPercent());
statsHolder.getProgressTracker().analyzingPercent.set(phaseProgress.getProgressPercent());
statsHolder.getProgressTracker().updatePhase(phaseProgress);
}
TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder();
if (inferenceModelBuilder != null) {

View File

@ -5,30 +5,85 @@
*/
package org.elasticsearch.xpack.ml.dataframe.stats;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
/**
* Tracks progress of a data frame analytics job.
* It includes phases "reindexing", "loading_data" and "writing_results"
* and allows for custom phases between "loading_data" and "writing_results".
*/
public class ProgressTracker {
public static final String REINDEXING = "reindexing";
public static final String LOADING_DATA = "loading_data";
public static final String ANALYZING = "analyzing";
public static final String WRITING_RESULTS = "writing_results";
public final AtomicInteger reindexingPercent = new AtomicInteger(0);
public final AtomicInteger loadingDataPercent = new AtomicInteger(0);
public final AtomicInteger analyzingPercent = new AtomicInteger(0);
public final AtomicInteger writingResultsPercent = new AtomicInteger(0);
private final String[] phasesInOrder;
private final Map<String, Integer> progressPercentPerPhase;
public static ProgressTracker fromZeroes(List<String> analysisProgressPhases) {
List<PhaseProgress> phases = new ArrayList<>(3 + analysisProgressPhases.size());
phases.add(new PhaseProgress(REINDEXING, 0));
phases.add(new PhaseProgress(LOADING_DATA, 0));
analysisProgressPhases.forEach(analysisPhase -> phases.add(new PhaseProgress(analysisPhase, 0)));
phases.add(new PhaseProgress(WRITING_RESULTS, 0));
return new ProgressTracker(phases);
}
public ProgressTracker(List<PhaseProgress> phaseProgresses) {
phasesInOrder = new String[phaseProgresses.size()];
progressPercentPerPhase = new ConcurrentHashMap<>();
for (int i = 0; i < phaseProgresses.size(); i++) {
PhaseProgress phaseProgress = phaseProgresses.get(i);
phasesInOrder[i] = phaseProgress.getPhase();
progressPercentPerPhase.put(phaseProgress.getPhase(), phaseProgress.getProgressPercent());
}
assert progressPercentPerPhase.containsKey(REINDEXING);
assert progressPercentPerPhase.containsKey(LOADING_DATA);
assert progressPercentPerPhase.containsKey(WRITING_RESULTS);
}
public void updateReindexingProgress(int progressPercent) {
progressPercentPerPhase.put(REINDEXING, progressPercent);
}
public int getReindexingProgressPercent() {
return progressPercentPerPhase.get(REINDEXING);
}
public void updateLoadingDataProgress(int progressPercent) {
progressPercentPerPhase.put(LOADING_DATA, progressPercent);
}
public void updateWritingResultsProgress(int progressPercent) {
progressPercentPerPhase.put(WRITING_RESULTS, progressPercent);
}
public int getWritingResultsProgressPercent() {
return progressPercentPerPhase.get(WRITING_RESULTS);
}
public void updatePhase(PhaseProgress phase) {
Integer newValue = progressPercentPerPhase.computeIfPresent(phase.getPhase(), (k, v) -> phase.getProgressPercent());
if (newValue == null) {
throw ExceptionsHelper.serverError("unknown progress phase [" + phase.getPhase() + "]");
}
}
public List<PhaseProgress> report() {
return Arrays.asList(
new PhaseProgress(REINDEXING, reindexingPercent.get()),
new PhaseProgress(LOADING_DATA, loadingDataPercent.get()),
new PhaseProgress(ANALYZING, analyzingPercent.get()),
new PhaseProgress(WRITING_RESULTS, writingResultsPercent.get())
);
return Collections.unmodifiableList(Arrays.stream(phasesInOrder)
.map(phase -> new PhaseProgress(phase, progressPercentPerPhase.get(phase)))
.collect(Collectors.toList()));
}
}

View File

@ -7,7 +7,9 @@ package org.elasticsearch.xpack.ml.dataframe.stats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
/**
@ -16,18 +18,24 @@ import java.util.concurrent.atomic.AtomicReference;
*/
public class StatsHolder {
private final ProgressTracker progressTracker;
private volatile ProgressTracker progressTracker;
private final AtomicReference<MemoryUsage> memoryUsageHolder;
private final AtomicReference<AnalysisStats> analysisStatsHolder;
private final DataCountsTracker dataCountsTracker;
public StatsHolder() {
progressTracker = new ProgressTracker();
public StatsHolder(List<PhaseProgress> progressOnStart) {
progressTracker = new ProgressTracker(progressOnStart);
memoryUsageHolder = new AtomicReference<>();
analysisStatsHolder = new AtomicReference<>();
dataCountsTracker = new DataCountsTracker();
}
public void resetProgressTrackerPreservingReindexingProgress(List<String> analysisPhases) {
int reindexingProgressPercent = progressTracker.getReindexingProgressPercent();
progressTracker = ProgressTracker.fromZeroes(analysisPhases);
progressTracker.updateReindexingProgress(reindexingProgressPercent);
}
public ProgressTracker getProgressTracker() {
return progressTracker;
}

View File

@ -22,6 +22,7 @@ import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@ -94,7 +95,8 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
task = mock(DataFrameAnalyticsTask.class);
when(task.getAllocationId()).thenReturn(TASK_ALLOCATION_ID);
when(task.getStatsHolder()).thenReturn(new StatsHolder());
when(task.getStatsHolder()).thenReturn(new StatsHolder(
ProgressTracker.fromZeroes(Collections.singletonList("analyzing")).report()));
when(task.getParentTaskId()).thenReturn(new TaskId(""));
dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandomBuilder(CONFIG_ID,
false,

View File

@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
@ -67,7 +68,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
private AnalyticsProcess<AnalyticsResult> process;
private DataFrameRowsJoiner dataFrameRowsJoiner;
private StatsHolder statsHolder = new StatsHolder();
private StatsHolder statsHolder = new StatsHolder(ProgressTracker.fromZeroes(Collections.singletonList("analyzing")).report());
private TrainedModelProvider trainedModelProvider;
private DataFrameAnalyticsAuditor auditor;
private StatsPersister statsPersister;
@ -114,7 +115,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
verify(dataFrameRowsJoiner).close();
Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner);
assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(100));
assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(100));
}
public void testProcess_GivenRowResults() {
@ -132,7 +133,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults1);
inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults2);
assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(100));
assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(100));
}
public void testProcess_GivenDataFrameRowsJoinerFails() {
@ -155,7 +156,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
assertThat(auditCaptor.getValue(), containsString("Error processing results; some failure"));
assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(0));
assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
}
@SuppressWarnings("unchecked")
@ -251,7 +252,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
Mockito.verifyNoMoreInteractions(auditor);
assertThat(resultProcessor.getFailure(), startsWith("error processing results; error storing trained model with id [" + JOB_ID));
assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(0));
assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
}
private void givenProcessResults(List<AnalyticsResult> results) {

View File

@ -0,0 +1,81 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.stats;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
public class ProgressTrackerTests extends ESTestCase {
public void testCtor() {
List<PhaseProgress> phases = Collections.unmodifiableList(
Arrays.asList(
new PhaseProgress("reindexing", 10),
new PhaseProgress("loading_data", 20),
new PhaseProgress("a", 30),
new PhaseProgress("b", 40),
new PhaseProgress("writing_results", 50)
)
);
ProgressTracker progressTracker = new ProgressTracker(phases);
assertThat(progressTracker.report(), equalTo(phases));
}
public void testFromZeroes() {
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Arrays.asList("a", "b", "c"));
List<PhaseProgress> phases = progressTracker.report();
assertThat(phases.size(), equalTo(6));
assertThat(phases.stream().map(PhaseProgress::getPhase).collect(Collectors.toList()),
contains("reindexing", "loading_data", "a", "b", "c", "writing_results"));
assertThat(phases.stream().map(PhaseProgress::getProgressPercent).allMatch(p -> p == 0), is(true));
}
public void testUpdates() {
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo"));
progressTracker.updateReindexingProgress(1);
progressTracker.updateLoadingDataProgress(2);
progressTracker.updatePhase(new PhaseProgress("foo", 3));
progressTracker.updateWritingResultsProgress(4);
assertThat(progressTracker.getReindexingProgressPercent(), equalTo(1));
assertThat(progressTracker.getWritingResultsProgressPercent(), equalTo(4));
List<PhaseProgress> phases = progressTracker.report();
assertThat(phases.size(), equalTo(4));
assertThat(phases.stream().map(PhaseProgress::getPhase).collect(Collectors.toList()),
contains("reindexing", "loading_data", "foo", "writing_results"));
assertThat(phases.get(0).getProgressPercent(), equalTo(1));
assertThat(phases.get(1).getProgressPercent(), equalTo(2));
assertThat(phases.get(2).getProgressPercent(), equalTo(3));
assertThat(phases.get(3).getProgressPercent(), equalTo(4));
}
public void testUpdatePhase_GivenUnknownPhase() {
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo"));
ElasticsearchException e = expectThrows(ElasticsearchException.class,
() -> progressTracker.updatePhase(new PhaseProgress("bar", 42)));
assertThat(e.getMessage(), equalTo("unknown progress phase [bar]"));
}
}

View File

@ -0,0 +1,73 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.stats;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
public class StatsHolderTests extends ESTestCase {
public void testResetProgressTrackerPreservingReindexingProgress_GivenSameAnalysisPhases() {
List<PhaseProgress> phases = Collections.unmodifiableList(
Arrays.asList(
new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("reindexing", 10),
new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("loading_data", 20),
new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("a", 30),
new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("b", 40),
new PhaseProgress("writing_results", 50)
)
);
StatsHolder statsHolder = new StatsHolder(phases);
statsHolder.resetProgressTrackerPreservingReindexingProgress(Arrays.asList("a", "b"));
List<PhaseProgress> phaseProgresses = statsHolder.getProgressTracker().report();
assertThat(phaseProgresses.size(), equalTo(5));
assertThat(phaseProgresses.stream().map(PhaseProgress::getPhase).collect(Collectors.toList()),
contains("reindexing", "loading_data", "a", "b", "writing_results"));
assertThat(phaseProgresses.get(0).getProgressPercent(), equalTo(10));
assertThat(phaseProgresses.get(1).getProgressPercent(), equalTo(0));
assertThat(phaseProgresses.get(2).getProgressPercent(), equalTo(0));
assertThat(phaseProgresses.get(3).getProgressPercent(), equalTo(0));
assertThat(phaseProgresses.get(4).getProgressPercent(), equalTo(0));
}
public void testResetProgressTrackerPreservingReindexingProgress_GivenDifferentAnalysisPhases() {
List<PhaseProgress> phases = Collections.unmodifiableList(
Arrays.asList(
new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("reindexing", 10),
new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("loading_data", 20),
new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("a", 30),
new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("b", 40),
new PhaseProgress("writing_results", 50)
)
);
StatsHolder statsHolder = new StatsHolder(phases);
statsHolder.resetProgressTrackerPreservingReindexingProgress(Arrays.asList("c", "d"));
List<PhaseProgress> phaseProgresses = statsHolder.getProgressTracker().report();
assertThat(phaseProgresses.size(), equalTo(5));
assertThat(phaseProgresses.stream().map(PhaseProgress::getPhase).collect(Collectors.toList()),
contains("reindexing", "loading_data", "c", "d", "writing_results"));
assertThat(phaseProgresses.get(0).getProgressPercent(), equalTo(10));
assertThat(phaseProgresses.get(1).getProgressPercent(), equalTo(0));
assertThat(phaseProgresses.get(2).getProgressPercent(), equalTo(0));
assertThat(phaseProgresses.get(3).getProgressPercent(), equalTo(0));
assertThat(phaseProgresses.get(4).getProgressPercent(), equalTo(0));
}
}