diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgress.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgress.java index 0f9617bceb1..0eaeff4fac0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgress.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgress.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.utils; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -80,4 +81,9 @@ public class PhaseProgress implements ToXContentObject, Writeable { PhaseProgress that = (PhaseProgress) o; return Objects.equals(phase, that.phase) && progressPercent == that.progressPercent; } + + @Override + public String toString() { + return Strings.toString(this); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java index f46e588facc..c8ce57e3bb4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTracker.java @@ -54,7 +54,7 @@ public class ProgressTracker { } public void updateReindexingProgress(int progressPercent) { - progressPercentPerPhase.put(REINDEXING, progressPercent); + updatePhase(REINDEXING, progressPercent); } public int getReindexingProgressPercent() { @@ -62,11 +62,15 @@ public class ProgressTracker { } public void updateLoadingDataProgress(int progressPercent) { - progressPercentPerPhase.put(LOADING_DATA, progressPercent); + updatePhase(LOADING_DATA, progressPercent); + } + + public int getLoadingDataProgressPercent() { + return progressPercentPerPhase.get(LOADING_DATA); } public void updateWritingResultsProgress(int progressPercent) { - progressPercentPerPhase.put(WRITING_RESULTS, progressPercent); + updatePhase(WRITING_RESULTS, progressPercent); } public int getWritingResultsProgressPercent() { @@ -74,7 +78,11 @@ public class ProgressTracker { } public void updatePhase(PhaseProgress phase) { - progressPercentPerPhase.computeIfPresent(phase.getPhase(), (k, v) -> phase.getProgressPercent()); + updatePhase(phase.getPhase(), phase.getProgressPercent()); + } + + private void updatePhase(String phase, int progress) { + progressPercentPerPhase.computeIfPresent(phase, (k, v) -> Math.max(v, progress)); } public List report() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java index 0706a5dbee8..c0890f59e23 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/stats/ProgressTrackerTests.java @@ -79,4 +79,56 @@ public class ProgressTrackerTests extends ESTestCase { assertThat(phases.stream().map(PhaseProgress::getPhase).collect(Collectors.toList()), contains("reindexing", "loading_data", "foo", "writing_results")); } + + public void testUpdateReindexingProgress_GivenLowerValueThanCurrentProgress() { + ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo")); + + progressTracker.updateReindexingProgress(10); + + progressTracker.updateReindexingProgress(11); + assertThat(progressTracker.getReindexingProgressPercent(), equalTo(11)); + + progressTracker.updateReindexingProgress(10); + assertThat(progressTracker.getReindexingProgressPercent(), equalTo(11)); + } + + public void testUpdateLoadingDataProgress_GivenLowerValueThanCurrentProgress() { + ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo")); + + progressTracker.updateLoadingDataProgress(20); + + progressTracker.updateLoadingDataProgress(21); + assertThat(progressTracker.getLoadingDataProgressPercent(), equalTo(21)); + + progressTracker.updateLoadingDataProgress(20); + assertThat(progressTracker.getLoadingDataProgressPercent(), equalTo(21)); + } + + public void testUpdateWritingResultsProgress_GivenLowerValueThanCurrentProgress() { + ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo")); + + progressTracker.updateWritingResultsProgress(30); + + progressTracker.updateWritingResultsProgress(31); + assertThat(progressTracker.getWritingResultsProgressPercent(), equalTo(31)); + + progressTracker.updateWritingResultsProgress(30); + assertThat(progressTracker.getWritingResultsProgressPercent(), equalTo(31)); + } + + public void testUpdatePhase_GivenLowerValueThanCurrentProgress() { + ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo")); + + progressTracker.updatePhase(new PhaseProgress("foo", 40)); + + progressTracker.updatePhase(new PhaseProgress("foo", 41)); + assertThat(getProgressForPhase(progressTracker, "foo"), equalTo(41)); + + progressTracker.updatePhase(new PhaseProgress("foo", 40)); + assertThat(getProgressForPhase(progressTracker, "foo"), equalTo(41)); + } + + private static int getProgressForPhase(ProgressTracker progressTracker, String phase) { + return progressTracker.report().stream().filter(p -> p.getPhase().equals(phase)).findFirst().get().getProgressPercent(); + } }