[7.x][ML] Ensure phase progress may only increase (#56339) (#56357)

Due to multi-threading it is possible that phase progress
updates written from the c++ process arrive reordered.
We can address this by ensuring that progress may only increase.

Closes #56282

Backport of #56339
This commit is contained in:
Dimitris Athanasiou 2020-05-07 19:46:58 +03:00 committed by GitHub
parent 8f4af292a7
commit d064eda2b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 70 additions and 4 deletions

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.utils; package org.elasticsearch.xpack.core.ml.utils;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
@ -80,4 +81,9 @@ public class PhaseProgress implements ToXContentObject, Writeable {
PhaseProgress that = (PhaseProgress) o; PhaseProgress that = (PhaseProgress) o;
return Objects.equals(phase, that.phase) && progressPercent == that.progressPercent; return Objects.equals(phase, that.phase) && progressPercent == that.progressPercent;
} }
@Override
public String toString() {
return Strings.toString(this);
}
} }

View File

@ -54,7 +54,7 @@ public class ProgressTracker {
} }
public void updateReindexingProgress(int progressPercent) { public void updateReindexingProgress(int progressPercent) {
progressPercentPerPhase.put(REINDEXING, progressPercent); updatePhase(REINDEXING, progressPercent);
} }
public int getReindexingProgressPercent() { public int getReindexingProgressPercent() {
@ -62,11 +62,15 @@ public class ProgressTracker {
} }
public void updateLoadingDataProgress(int progressPercent) { public void updateLoadingDataProgress(int progressPercent) {
progressPercentPerPhase.put(LOADING_DATA, progressPercent); updatePhase(LOADING_DATA, progressPercent);
}
public int getLoadingDataProgressPercent() {
return progressPercentPerPhase.get(LOADING_DATA);
} }
public void updateWritingResultsProgress(int progressPercent) { public void updateWritingResultsProgress(int progressPercent) {
progressPercentPerPhase.put(WRITING_RESULTS, progressPercent); updatePhase(WRITING_RESULTS, progressPercent);
} }
public int getWritingResultsProgressPercent() { public int getWritingResultsProgressPercent() {
@ -74,7 +78,11 @@ public class ProgressTracker {
} }
public void updatePhase(PhaseProgress phase) { public void updatePhase(PhaseProgress phase) {
progressPercentPerPhase.computeIfPresent(phase.getPhase(), (k, v) -> phase.getProgressPercent()); updatePhase(phase.getPhase(), phase.getProgressPercent());
}
private void updatePhase(String phase, int progress) {
progressPercentPerPhase.computeIfPresent(phase, (k, v) -> Math.max(v, progress));
} }
public List<PhaseProgress> report() { public List<PhaseProgress> report() {

View File

@ -79,4 +79,56 @@ public class ProgressTrackerTests extends ESTestCase {
assertThat(phases.stream().map(PhaseProgress::getPhase).collect(Collectors.toList()), assertThat(phases.stream().map(PhaseProgress::getPhase).collect(Collectors.toList()),
contains("reindexing", "loading_data", "foo", "writing_results")); contains("reindexing", "loading_data", "foo", "writing_results"));
} }
public void testUpdateReindexingProgress_GivenLowerValueThanCurrentProgress() {
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo"));
progressTracker.updateReindexingProgress(10);
progressTracker.updateReindexingProgress(11);
assertThat(progressTracker.getReindexingProgressPercent(), equalTo(11));
progressTracker.updateReindexingProgress(10);
assertThat(progressTracker.getReindexingProgressPercent(), equalTo(11));
}
public void testUpdateLoadingDataProgress_GivenLowerValueThanCurrentProgress() {
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo"));
progressTracker.updateLoadingDataProgress(20);
progressTracker.updateLoadingDataProgress(21);
assertThat(progressTracker.getLoadingDataProgressPercent(), equalTo(21));
progressTracker.updateLoadingDataProgress(20);
assertThat(progressTracker.getLoadingDataProgressPercent(), equalTo(21));
}
public void testUpdateWritingResultsProgress_GivenLowerValueThanCurrentProgress() {
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo"));
progressTracker.updateWritingResultsProgress(30);
progressTracker.updateWritingResultsProgress(31);
assertThat(progressTracker.getWritingResultsProgressPercent(), equalTo(31));
progressTracker.updateWritingResultsProgress(30);
assertThat(progressTracker.getWritingResultsProgressPercent(), equalTo(31));
}
public void testUpdatePhase_GivenLowerValueThanCurrentProgress() {
ProgressTracker progressTracker = ProgressTracker.fromZeroes(Collections.singletonList("foo"));
progressTracker.updatePhase(new PhaseProgress("foo", 40));
progressTracker.updatePhase(new PhaseProgress("foo", 41));
assertThat(getProgressForPhase(progressTracker, "foo"), equalTo(41));
progressTracker.updatePhase(new PhaseProgress("foo", 40));
assertThat(getProgressForPhase(progressTracker, "foo"), equalTo(41));
}
private static int getProgressForPhase(ProgressTracker progressTracker, String phase) {
return progressTracker.report().stream().filter(p -> p.getPhase().equals(phase)).findFirst().get().getProgressPercent();
}
} }