From 7f100c1196276204a2e026c2ac87d632ced601e6 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 27 Apr 2020 13:30:05 +0300 Subject: [PATCH] [7.x][ML] Allow analytics process define its own progress phases (#55763) (#55791) This is a continuation from #55580. Now that we're parsing phase progresses from the analytics process we change `ProgressTracker` to allow for custom phases between the `loading_data` and `writing_results` phases. Each `DataFrameAnalysis` may declare its own phases. This commit sets things in place for the analytics process to start reporting different phases per analysis type. However, this is still preserving existing behaviour as all analyses currently declare a single `analyzing` phase. Backport of #55763 --- .../ml/dataframe/analyses/Classification.java | 5 ++ .../dataframe/analyses/DataFrameAnalysis.java | 5 ++ .../dataframe/analyses/OutlierDetection.java | 5 ++ .../ml/dataframe/analyses/Regression.java | 5 ++ ...sportGetDataFrameAnalyticsStatsAction.java | 73 ++++++++--------- .../dataframe/DataFrameAnalyticsManager.java | 11 ++- .../ml/dataframe/DataFrameAnalyticsTask.java | 30 ++----- .../process/AnalyticsProcessManager.java | 2 +- .../process/AnalyticsResultProcessor.java | 6 +- .../ml/dataframe/stats/ProgressTracker.java | 79 +++++++++++++++--- .../xpack/ml/dataframe/stats/StatsHolder.java | 14 +++- .../process/AnalyticsProcessManagerTests.java | 4 +- .../AnalyticsResultProcessorTests.java | 11 +-- .../dataframe/stats/ProgressTrackerTests.java | 81 +++++++++++++++++++ .../ml/dataframe/stats/StatsHolderTests.java | 73 +++++++++++++++++ 15 files changed, 318 insertions(+), 86 deletions(-) create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 93ad7e7d85d..83264c66746 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -348,6 +348,11 @@ public class Classification implements DataFrameAnalysis { return jobId + STATE_DOC_ID_SUFFIX; } + @Override + public List getProgressPhases() { + return Collections.singletonList("analyzing"); + } + public static String extractJobIdFromStateDoc(String stateDocId) { int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX); return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java index 941224dc30a..f8521d004a2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java @@ -66,6 +66,11 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { */ String getStateDocId(String jobId); + /** + * Returns the progress phases the analysis goes through in order + */ + List getProgressPhases(); + /** * Summarizes information about the fields that is necessary for analysis to generate * the parameters needed for the process configuration. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java index 2c83afa8780..2d955c9c754 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -249,6 +249,11 @@ public class OutlierDetection implements DataFrameAnalysis { throw new UnsupportedOperationException("Outlier detection does not support state"); } + @Override + public List getProgressPhases() { + return Collections.singletonList("analyzing"); + } + public enum Method { LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 824d4f95a17..75a1e83da83 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -213,6 +213,11 @@ public class Regression implements DataFrameAnalysis { return jobId + STATE_DOC_ID_SUFFIX; } + @Override + public List getProgressPhases() { + return Collections.singletonList("analyzing"); + } + public static String extractJobIdFromStateDoc(String stateDocId) { int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX); return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java index a8a40ef4c59..03eff724781 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java @@ -44,9 +44,9 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields; -import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; @@ -55,7 +55,6 @@ import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask; import org.elasticsearch.xpack.ml.dataframe.StoredProgress; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; -import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; import org.elasticsearch.xpack.ml.utils.persistence.MlParserUtils; import java.util.ArrayList; @@ -105,25 +104,20 @@ public class TransportGetDataFrameAnalyticsStatsAction ActionListener> listener) { logger.debug("Get stats for running task [{}]", task.getParams().getId()); - ActionListener statsHolderListener = ActionListener.wrap( - statsHolder -> { + ActionListener reindexingProgressListener = ActionListener.wrap( + aVoid -> { Stats stats = buildStats( task.getParams().getId(), - statsHolder.getProgressTracker().report(), - statsHolder.getDataCountsTracker().report(task.getParams().getId()), - statsHolder.getMemoryUsage(), - statsHolder.getAnalysisStats() + task.getStatsHolder().getProgressTracker().report(), + task.getStatsHolder().getDataCountsTracker().report(task.getParams().getId()), + task.getStatsHolder().getMemoryUsage(), + task.getStatsHolder().getAnalysisStats() ); listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1, GetDataFrameAnalyticsAction.Response.RESULTS_FIELD)); }, listener::onFailure ); - ActionListener reindexingProgressListener = ActionListener.wrap( - aVoid -> statsHolderListener.onResponse(task.getStatsHolder()), - listener::onFailure - ); - task.updateReindexTaskProgress(reindexingProgressListener); } @@ -138,7 +132,7 @@ public class TransportGetDataFrameAnalyticsStatsAction .collect(Collectors.toList()); request.setExpandedIds(expandedIds); ActionListener runningTasksStatsListener = ActionListener.wrap( - runningTasksStatsResponse -> gatherStatsForStoppedTasks(request.getExpandedIds(), runningTasksStatsResponse, + runningTasksStatsResponse -> gatherStatsForStoppedTasks(getResponse.getResources().results(), runningTasksStatsResponse, ActionListener.wrap( finalResponse -> { @@ -163,20 +157,20 @@ public class TransportGetDataFrameAnalyticsStatsAction executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsAction.INSTANCE, getRequest, getResponseListener); } - void gatherStatsForStoppedTasks(List expandedIds, GetDataFrameAnalyticsStatsAction.Response runningTasksResponse, + void gatherStatsForStoppedTasks(List configs, GetDataFrameAnalyticsStatsAction.Response runningTasksResponse, ActionListener listener) { - List stoppedTasksIds = determineStoppedTasksIds(expandedIds, runningTasksResponse.getResponse().results()); - if (stoppedTasksIds.isEmpty()) { + List stoppedConfigs = determineStoppedConfigs(configs, runningTasksResponse.getResponse().results()); + if (stoppedConfigs.isEmpty()) { listener.onResponse(runningTasksResponse); return; } - AtomicInteger counter = new AtomicInteger(stoppedTasksIds.size()); - AtomicArray jobStats = new AtomicArray<>(stoppedTasksIds.size()); - for (int i = 0; i < stoppedTasksIds.size(); i++) { + AtomicInteger counter = new AtomicInteger(stoppedConfigs.size()); + AtomicArray jobStats = new AtomicArray<>(stoppedConfigs.size()); + for (int i = 0; i < stoppedConfigs.size(); i++) { final int slot = i; - String jobId = stoppedTasksIds.get(i); - searchStats(jobId, ActionListener.wrap( + DataFrameAnalyticsConfig config = stoppedConfigs.get(i); + searchStats(config, ActionListener.wrap( stats -> { jobStats.set(slot, stats); if (counter.decrementAndGet() == 0) { @@ -192,21 +186,24 @@ public class TransportGetDataFrameAnalyticsStatsAction } } - static List determineStoppedTasksIds(List expandedIds, List runningTasksStats) { + static List determineStoppedConfigs(List configs, List runningTasksStats) { Set startedTasksIds = runningTasksStats.stream().map(Stats::getId).collect(Collectors.toSet()); - return expandedIds.stream().filter(id -> startedTasksIds.contains(id) == false).collect(Collectors.toList()); + return configs.stream().filter(config -> startedTasksIds.contains(config.getId()) == false).collect(Collectors.toList()); } - private void searchStats(String configId, ActionListener listener) { - RetrievedStatsHolder retrievedStatsHolder = new RetrievedStatsHolder(); + private void searchStats(DataFrameAnalyticsConfig config, ActionListener listener) { + logger.debug("[{}] Gathering stats for stopped task", config.getId()); + + RetrievedStatsHolder retrievedStatsHolder = new RetrievedStatsHolder( + ProgressTracker.fromZeroes(config.getAnalysis().getProgressPhases()).report()); MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); - multiSearchRequest.add(buildStoredProgressSearch(configId)); - multiSearchRequest.add(buildStatsDocSearch(configId, DataCounts.TYPE_VALUE)); - multiSearchRequest.add(buildStatsDocSearch(configId, MemoryUsage.TYPE_VALUE)); - multiSearchRequest.add(buildStatsDocSearch(configId, OutlierDetectionStats.TYPE_VALUE)); - multiSearchRequest.add(buildStatsDocSearch(configId, ClassificationStats.TYPE_VALUE)); - multiSearchRequest.add(buildStatsDocSearch(configId, RegressionStats.TYPE_VALUE)); + multiSearchRequest.add(buildStoredProgressSearch(config.getId())); + multiSearchRequest.add(buildStatsDocSearch(config.getId(), DataCounts.TYPE_VALUE)); + multiSearchRequest.add(buildStatsDocSearch(config.getId(), MemoryUsage.TYPE_VALUE)); + multiSearchRequest.add(buildStatsDocSearch(config.getId(), OutlierDetectionStats.TYPE_VALUE)); + multiSearchRequest.add(buildStatsDocSearch(config.getId(), ClassificationStats.TYPE_VALUE)); + multiSearchRequest.add(buildStatsDocSearch(config.getId(), RegressionStats.TYPE_VALUE)); executeAsyncWithOrigin(client, ML_ORIGIN, MultiSearchAction.INSTANCE, multiSearchRequest, ActionListener.wrap( multiSearchResponse -> { @@ -218,7 +215,7 @@ public class TransportGetDataFrameAnalyticsStatsAction logger.error( new ParameterizedMessage( "[{}] Item failure encountered during multi search for request [indices={}, source={}]: {}", - configId, itemRequest.indices(), itemRequest.source(), itemResponse.getFailureMessage()), + config.getId(), itemRequest.indices(), itemRequest.source(), itemResponse.getFailureMessage()), itemResponse.getFailure()); listener.onFailure(ExceptionsHelper.serverError(itemResponse.getFailureMessage(), itemResponse.getFailure())); return; @@ -227,13 +224,13 @@ public class TransportGetDataFrameAnalyticsStatsAction if (hits.length == 0) { // Not found } else if (hits.length == 1) { - parseHit(hits[0], configId, retrievedStatsHolder); + parseHit(hits[0], config.getId(), retrievedStatsHolder); } else { throw ExceptionsHelper.serverError("Found [" + hits.length + "] hits when just one was requested"); } } } - listener.onResponse(buildStats(configId, + listener.onResponse(buildStats(config.getId(), retrievedStatsHolder.progress.get(), retrievedStatsHolder.dataCounts, retrievedStatsHolder.memoryUsage, @@ -320,9 +317,13 @@ public class TransportGetDataFrameAnalyticsStatsAction private static class RetrievedStatsHolder { - private volatile StoredProgress progress = new StoredProgress(new ProgressTracker().report()); + private volatile StoredProgress progress; private volatile DataCounts dataCounts; private volatile MemoryUsage memoryUsage; private volatile AnalysisStats analysisStats; + + private RetrievedStatsHolder(List defaultProgress) { + progress = new StoredProgress(defaultProgress); + } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index 5afb9f687ec..22c7ab0a25a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -81,6 +81,11 @@ public class DataFrameAnalyticsManager { // With config in hand, determine action to take ActionListener configListener = ActionListener.wrap( config -> { + // At this point we have the config at hand and we can reset the progress tracker + // to use the analyses phases. We preserve reindexing progress as if reindexing was + // finished it will not be reset. + task.getStatsHolder().resetProgressTrackerPreservingReindexingProgress(config.getAnalysis().getProgressPhases()); + switch(currentState) { // If we are STARTED, it means the job was started because the start API was called. // We should determine the job's starting state based on its previous progress. @@ -217,7 +222,6 @@ public class DataFrameAnalyticsManager { return; } task.setReindexingTaskId(null); - task.setReindexingFinished(); auditor.info( config.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_FINISHED_REINDEXING, config.getDest().getIndex(), @@ -296,6 +300,7 @@ public class DataFrameAnalyticsManager { task.markAsCompleted(); return; } + final ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId()); // Update state to ANALYZING and start process ActionListener dataExtractorFactoryListener = ActionListener.wrap( @@ -327,8 +332,8 @@ public class DataFrameAnalyticsManager { ActionListener refreshListener = ActionListener.wrap( refreshResponse -> { - // Ensure we mark reindexing is finished for the case we are recovering a task that had finished reindexing - task.setReindexingFinished(); + // Now we can ensure reindexing progress is complete + task.getStatsHolder().getProgressTracker().updateReindexingProgress(100); // TODO This could fail with errors. In that case we get stuck with the copied index. // We could delete the index in case of failure or we could try building the factory before reindexing diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java index b43841cda9f..5ff9ede8673 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java @@ -67,10 +67,9 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S private final StartDataFrameAnalyticsAction.TaskParams taskParams; @Nullable private volatile Long reindexingTaskId; - private volatile boolean isReindexingFinished; private volatile boolean isStopping; private volatile boolean isMarkAsCompletedCalled; - private final StatsHolder statsHolder = new StatsHolder(); + private final StatsHolder statsHolder; public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map headers, Client client, ClusterService clusterService, DataFrameAnalyticsManager analyticsManager, @@ -81,6 +80,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S this.analyticsManager = Objects.requireNonNull(analyticsManager); this.auditor = Objects.requireNonNull(auditor); this.taskParams = Objects.requireNonNull(taskParams); + this.statsHolder = new StatsHolder(taskParams.getProgressOnStart()); } public StartDataFrameAnalyticsAction.TaskParams getParams() { @@ -92,10 +92,6 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S this.reindexingTaskId = reindexingTaskId; } - public void setReindexingFinished() { - isReindexingFinished = true; - } - public boolean isStopping() { return isStopping; } @@ -222,7 +218,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S // We set reindexing progress at least to 1 for a running process to be able to // distinguish a job that is running for the first time against a job that is restarting. reindexTaskProgress -> { - statsHolder.getProgressTracker().reindexingPercent.set(Math.max(1, reindexTaskProgress)); + statsHolder.getProgressTracker().updateReindexingProgress(Math.max(1, reindexTaskProgress)); listener.onResponse(null); }, listener::onFailure @@ -232,9 +228,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S private void getReindexTaskProgress(ActionListener listener) { TaskId reindexTaskId = getReindexTaskId(); if (reindexTaskId == null) { - // The task is not present which means either it has not started yet or it finished. - // We keep track of whether the task has finished so we can use that to tell whether the progress 100. - listener.onResponse(isReindexingFinished ? 100 : 0); + listener.onResponse(statsHolder.getProgressTracker().getReindexingProgressPercent()); return; } @@ -250,8 +244,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S error -> { if (ExceptionsHelper.unwrapCause(error) instanceof ResourceNotFoundException) { // The task is not present which means either it has not started yet or it finished. - // We keep track of whether the task has finished so we can use that to tell whether the progress 100. - listener.onResponse(isReindexingFinished ? 100 : 0); + listener.onResponse(statsHolder.getProgressTracker().getReindexingProgressPercent()); } else { listener.onFailure(error); } @@ -365,17 +358,10 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S LOGGER.debug("[{}] Last incomplete progress [{}, {}]", jobId, lastIncompletePhase.getPhase(), lastIncompletePhase.getProgressPercent()); - switch (lastIncompletePhase.getPhase()) { - case ProgressTracker.REINDEXING: - return lastIncompletePhase.getProgressPercent() == 0 ? StartingState.FIRST_TIME : StartingState.RESUMING_REINDEXING; - case ProgressTracker.LOADING_DATA: - case ProgressTracker.ANALYZING: - case ProgressTracker.WRITING_RESULTS: - return StartingState.RESUMING_ANALYZING; - default: - LOGGER.warn("[{}] Unexpected progress phase [{}]", jobId, lastIncompletePhase.getPhase()); - return StartingState.FIRST_TIME; + if (ProgressTracker.REINDEXING.equals(lastIncompletePhase.getPhase())) { + return lastIncompletePhase.getProgressPercent() == 0 ? StartingState.FIRST_TIME : StartingState.RESUMING_REINDEXING; } + return StartingState.RESUMING_ANALYZING; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 5ad1eda8e4f..9bb77947eab 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -246,7 +246,7 @@ public class AnalyticsProcessManager { } } rowsProcessed += rows.get().size(); - progressTracker.loadingDataPercent.set(rowsProcessed >= totalRows ? 100 : (int) (rowsProcessed * 100.0 / totalRows)); + progressTracker.updateLoadingDataProgress(rowsProcessed >= totalRows ? 100 : (int) (rowsProcessed * 100.0 / totalRows)); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 8921ee39249..cd9ad2baf1f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -153,11 +153,11 @@ public class AnalyticsResultProcessor { } private void updateResultsProgress(int progress) { - statsHolder.getProgressTracker().writingResultsPercent.set(Math.min(progress, MAX_PROGRESS_BEFORE_COMPLETION)); + statsHolder.getProgressTracker().updateWritingResultsProgress(Math.min(progress, MAX_PROGRESS_BEFORE_COMPLETION)); } private void completeResultsProgress() { - statsHolder.getProgressTracker().writingResultsPercent.set(100); + statsHolder.getProgressTracker().updateWritingResultsProgress(100); } private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJoiner) { @@ -169,7 +169,7 @@ public class AnalyticsResultProcessor { if (phaseProgress != null) { LOGGER.debug("[{}] progress for phase [{}] updated to [{}]", analytics.getId(), phaseProgress.getPhase(), phaseProgress.getProgressPercent()); - statsHolder.getProgressTracker().analyzingPercent.set(phaseProgress.getProgressPercent()); + statsHolder.getProgressTracker().updatePhase(phaseProgress); } TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder(); if (inferenceModelBuilder != null) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java index 0c627072105..0f731e08307 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java @@ -5,30 +5,85 @@ */ package org.elasticsearch.xpack.ml.dataframe.stats; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; +/** + * Tracks progress of a data frame analytics job. + * It includes phases "reindexing", "loading_data" and "writing_results" + * and allows for custom phases between "loading_data" and "writing_results". + */ public class ProgressTracker { public static final String REINDEXING = "reindexing"; public static final String LOADING_DATA = "loading_data"; - public static final String ANALYZING = "analyzing"; public static final String WRITING_RESULTS = "writing_results"; - public final AtomicInteger reindexingPercent = new AtomicInteger(0); - public final AtomicInteger loadingDataPercent = new AtomicInteger(0); - public final AtomicInteger analyzingPercent = new AtomicInteger(0); - public final AtomicInteger writingResultsPercent = new AtomicInteger(0); + private final String[] phasesInOrder; + private final Map progressPercentPerPhase; + + public static ProgressTracker fromZeroes(List analysisProgressPhases) { + List phases = new ArrayList<>(3 + analysisProgressPhases.size()); + phases.add(new PhaseProgress(REINDEXING, 0)); + phases.add(new PhaseProgress(LOADING_DATA, 0)); + analysisProgressPhases.forEach(analysisPhase -> phases.add(new PhaseProgress(analysisPhase, 0))); + phases.add(new PhaseProgress(WRITING_RESULTS, 0)); + return new ProgressTracker(phases); + } + + public ProgressTracker(List phaseProgresses) { + phasesInOrder = new String[phaseProgresses.size()]; + progressPercentPerPhase = new ConcurrentHashMap<>(); + + for (int i = 0; i < phaseProgresses.size(); i++) { + PhaseProgress phaseProgress = phaseProgresses.get(i); + phasesInOrder[i] = phaseProgress.getPhase(); + progressPercentPerPhase.put(phaseProgress.getPhase(), phaseProgress.getProgressPercent()); + } + + assert progressPercentPerPhase.containsKey(REINDEXING); + assert progressPercentPerPhase.containsKey(LOADING_DATA); + assert progressPercentPerPhase.containsKey(WRITING_RESULTS); + } + + public void updateReindexingProgress(int progressPercent) { + progressPercentPerPhase.put(REINDEXING, progressPercent); + } + + public int getReindexingProgressPercent() { + return progressPercentPerPhase.get(REINDEXING); + } + + public void updateLoadingDataProgress(int progressPercent) { + progressPercentPerPhase.put(LOADING_DATA, progressPercent); + } + + public void updateWritingResultsProgress(int progressPercent) { + progressPercentPerPhase.put(WRITING_RESULTS, progressPercent); + } + + public int getWritingResultsProgressPercent() { + return progressPercentPerPhase.get(WRITING_RESULTS); + } + + public void updatePhase(PhaseProgress phase) { + Integer newValue = progressPercentPerPhase.computeIfPresent(phase.getPhase(), (k, v) -> phase.getProgressPercent()); + if (newValue == null) { + throw ExceptionsHelper.serverError("unknown progress phase [" + phase.getPhase() + "]"); + } + } public List report() { - return Arrays.asList( - new PhaseProgress(REINDEXING, reindexingPercent.get()), - new PhaseProgress(LOADING_DATA, loadingDataPercent.get()), - new PhaseProgress(ANALYZING, analyzingPercent.get()), - new PhaseProgress(WRITING_RESULTS, writingResultsPercent.get()) - ); + return Collections.unmodifiableList(Arrays.stream(phasesInOrder) + .map(phase -> new PhaseProgress(phase, progressPercentPerPhase.get(phase))) + .collect(Collectors.toList())); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java index d6d23602123..d1724bd2f11 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolder.java @@ -7,7 +7,9 @@ package org.elasticsearch.xpack.ml.dataframe.stats; import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; +import java.util.List; import java.util.concurrent.atomic.AtomicReference; /** @@ -16,18 +18,24 @@ import java.util.concurrent.atomic.AtomicReference; */ public class StatsHolder { - private final ProgressTracker progressTracker; + private volatile ProgressTracker progressTracker; private final AtomicReference memoryUsageHolder; private final AtomicReference analysisStatsHolder; private final DataCountsTracker dataCountsTracker; - public StatsHolder() { - progressTracker = new ProgressTracker(); + public StatsHolder(List progressOnStart) { + progressTracker = new ProgressTracker(progressOnStart); memoryUsageHolder = new AtomicReference<>(); analysisStatsHolder = new AtomicReference<>(); dataCountsTracker = new DataCountsTracker(); } + public void resetProgressTrackerPreservingReindexingProgress(List analysisPhases) { + int reindexingProgressPercent = progressTracker.getReindexingProgressPercent(); + progressTracker = ProgressTracker.fromZeroes(analysisPhases); + progressTracker.updateReindexingProgress(reindexingProgressPercent); + } + public ProgressTracker getProgressTracker() { return progressTracker; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java index 2c62ebbfa0b..d8994fe3d8f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; +import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -94,7 +95,8 @@ public class AnalyticsProcessManagerTests extends ESTestCase { task = mock(DataFrameAnalyticsTask.class); when(task.getAllocationId()).thenReturn(TASK_ALLOCATION_ID); - when(task.getStatsHolder()).thenReturn(new StatsHolder()); + when(task.getStatsHolder()).thenReturn(new StatsHolder( + ProgressTracker.fromZeroes(Collections.singletonList("analyzing")).report())); when(task.getParentTaskId()).thenReturn(new TaskId("")); dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandomBuilder(CONFIG_ID, false, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 7590d9a5a76..530bf280aaf 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.security.user.XPackUser; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; import org.elasticsearch.xpack.ml.extractor.DocValueField; @@ -67,7 +68,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { private AnalyticsProcess process; private DataFrameRowsJoiner dataFrameRowsJoiner; - private StatsHolder statsHolder = new StatsHolder(); + private StatsHolder statsHolder = new StatsHolder(ProgressTracker.fromZeroes(Collections.singletonList("analyzing")).report()); private TrainedModelProvider trainedModelProvider; private DataFrameAnalyticsAuditor auditor; private StatsPersister statsPersister; @@ -114,7 +115,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { verify(dataFrameRowsJoiner).close(); Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner); - assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(100)); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(100)); } public void testProcess_GivenRowResults() { @@ -132,7 +133,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults1); inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults2); - assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(100)); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(100)); } public void testProcess_GivenDataFrameRowsJoinerFails() { @@ -155,7 +156,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { verify(auditor).error(eq(JOB_ID), auditCaptor.capture()); assertThat(auditCaptor.getValue(), containsString("Error processing results; some failure")); - assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(0)); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); } @SuppressWarnings("unchecked") @@ -251,7 +252,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { Mockito.verifyNoMoreInteractions(auditor); assertThat(resultProcessor.getFailure(), startsWith("error processing results; error storing trained model with id [" + JOB_ID)); - assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(0)); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); } private void givenProcessResults(List results) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java new file mode 100644 index 00000000000..eec0ac50129 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.dataframe.stats; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class ProgressTrackerTests extends ESTestCase { + + public void testCtor() { + List phases = Collections.unmodifiableList( + Arrays.asList( + new PhaseProgress("reindexing", 10), + new PhaseProgress("loading_data", 20), + new PhaseProgress("a", 30), + new PhaseProgress("b", 40), + new PhaseProgress("writing_results", 50) + ) + ); + + ProgressTracker progressTracker = new ProgressTracker(phases); + + assertThat(progressTracker.report(), equalTo(phases)); + } + + public void testFromZeroes() { + ProgressTracker progressTracker = ProgressTracker.fromZeroes(Arrays.asList("a", "b", "c")); + + List phases = progressTracker.report(); + + assertThat(phases.size(), equalTo(6)); + assertThat(phases.stream().map(PhaseProgress::getPhase).collect(Collectors.toList()), + contains("reindexing", "loading_data", "a", "b", "c", "writing_results")); + assertThat(phases.stream().map(PhaseProgress::getProgressPercent).allMatch(p -> p == 0), is(true)); + } + + public void testUpdates() { + ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo")); + + progressTracker.updateReindexingProgress(1); + progressTracker.updateLoadingDataProgress(2); + progressTracker.updatePhase(new PhaseProgress("foo", 3)); + progressTracker.updateWritingResultsProgress(4); + + assertThat(progressTracker.getReindexingProgressPercent(), equalTo(1)); + assertThat(progressTracker.getWritingResultsProgressPercent(), equalTo(4)); + + List phases = progressTracker.report(); + + assertThat(phases.size(), equalTo(4)); + assertThat(phases.stream().map(PhaseProgress::getPhase).collect(Collectors.toList()), + contains("reindexing", "loading_data", "foo", "writing_results")); + assertThat(phases.get(0).getProgressPercent(), equalTo(1)); + assertThat(phases.get(1).getProgressPercent(), equalTo(2)); + assertThat(phases.get(2).getProgressPercent(), equalTo(3)); + assertThat(phases.get(3).getProgressPercent(), equalTo(4)); + } + + public void testUpdatePhase_GivenUnknownPhase() { + ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo")); + + ElasticsearchException e = expectThrows(ElasticsearchException.class, + () -> progressTracker.updatePhase(new PhaseProgress("bar", 42))); + + assertThat(e.getMessage(), equalTo("unknown progress phase [bar]")); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java new file mode 100644 index 00000000000..39736cc9068 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsHolderTests.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.dataframe.stats; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; + +public class StatsHolderTests extends ESTestCase { + + public void testResetProgressTrackerPreservingReindexingProgress_GivenSameAnalysisPhases() { + List phases = Collections.unmodifiableList( + Arrays.asList( + new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("reindexing", 10), + new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("loading_data", 20), + new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("a", 30), + new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("b", 40), + new PhaseProgress("writing_results", 50) + ) + ); + StatsHolder statsHolder = new StatsHolder(phases); + + statsHolder.resetProgressTrackerPreservingReindexingProgress(Arrays.asList("a", "b")); + + List phaseProgresses = statsHolder.getProgressTracker().report(); + + assertThat(phaseProgresses.size(), equalTo(5)); + assertThat(phaseProgresses.stream().map(PhaseProgress::getPhase).collect(Collectors.toList()), + contains("reindexing", "loading_data", "a", "b", "writing_results")); + assertThat(phaseProgresses.get(0).getProgressPercent(), equalTo(10)); + assertThat(phaseProgresses.get(1).getProgressPercent(), equalTo(0)); + assertThat(phaseProgresses.get(2).getProgressPercent(), equalTo(0)); + assertThat(phaseProgresses.get(3).getProgressPercent(), equalTo(0)); + assertThat(phaseProgresses.get(4).getProgressPercent(), equalTo(0)); + } + + public void testResetProgressTrackerPreservingReindexingProgress_GivenDifferentAnalysisPhases() { + List phases = Collections.unmodifiableList( + Arrays.asList( + new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("reindexing", 10), + new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("loading_data", 20), + new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("a", 30), + new org.elasticsearch.xpack.core.ml.utils.PhaseProgress("b", 40), + new PhaseProgress("writing_results", 50) + ) + ); + StatsHolder statsHolder = new StatsHolder(phases); + + statsHolder.resetProgressTrackerPreservingReindexingProgress(Arrays.asList("c", "d")); + + List phaseProgresses = statsHolder.getProgressTracker().report(); + + assertThat(phaseProgresses.size(), equalTo(5)); + assertThat(phaseProgresses.stream().map(PhaseProgress::getPhase).collect(Collectors.toList()), + contains("reindexing", "loading_data", "c", "d", "writing_results")); + assertThat(phaseProgresses.get(0).getProgressPercent(), equalTo(10)); + assertThat(phaseProgresses.get(1).getProgressPercent(), equalTo(0)); + assertThat(phaseProgresses.get(2).getProgressPercent(), equalTo(0)); + assertThat(phaseProgresses.get(3).getProgressPercent(), equalTo(0)); + assertThat(phaseProgresses.get(4).getProgressPercent(), equalTo(0)); + } +}