[7.x][ML] Improve progress reportings for DF analytics (#45856) (#45910)

Previously, the stats API reported a progress percentage
for DF analytics tasks that were running and in the
`reindexing` or `analyzing` state.

This meant that when the task was `stopped` there was no progress
reported. Thus, one could not distinguish between a task that never
ran and one that completed.

In addition, there are blind spots in the progress reporting.
In particular, we do not account for when data is loaded into the
process. We also do not account for when results are written.

This commit addresses the above issues. It changes progress
to being a list of objects, each one describing the phase
and its progress as a percentage. We currently have 4 phases:
reindexing, loading_data, analyzing, writing_results.

When the task stops, progress is persisted as a document in the
state index. The stats API now reports progress from in-memory
if the task is running, or returns the persisted document
(if there is one).
This commit is contained in:
Dimitris Athanasiou 2019-08-23 23:04:39 +03:00 committed by GitHub
parent b756e1b9be
commit be554fe5f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 989 additions and 129 deletions

View File

@ -28,6 +28,7 @@ import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
@ -42,17 +43,18 @@ public class DataFrameAnalyticsStats {
static final ParseField ID = new ParseField("id");
static final ParseField STATE = new ParseField("state");
static final ParseField FAILURE_REASON = new ParseField("failure_reason");
static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent");
static final ParseField PROGRESS = new ParseField("progress");
static final ParseField NODE = new ParseField("node");
static final ParseField ASSIGNMENT_EXPLANATION = new ParseField("assignment_explanation");
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<DataFrameAnalyticsStats, Void> PARSER =
new ConstructingObjectParser<>("data_frame_analytics_stats", true,
args -> new DataFrameAnalyticsStats(
(String) args[0],
(DataFrameAnalyticsState) args[1],
(String) args[2],
(Integer) args[3],
(List<PhaseProgress>) args[3],
(NodeAttributes) args[4],
(String) args[5]));
@ -65,7 +67,7 @@ public class DataFrameAnalyticsStats {
throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
}, STATE, ObjectParser.ValueType.STRING);
PARSER.declareString(optionalConstructorArg(), FAILURE_REASON);
PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT);
PARSER.declareObjectArray(optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS);
PARSER.declareObject(optionalConstructorArg(), NodeAttributes.PARSER, NODE);
PARSER.declareString(optionalConstructorArg(), ASSIGNMENT_EXPLANATION);
}
@ -73,17 +75,17 @@ public class DataFrameAnalyticsStats {
private final String id;
private final DataFrameAnalyticsState state;
private final String failureReason;
private final Integer progressPercent;
private final List<PhaseProgress> progress;
private final NodeAttributes node;
private final String assignmentExplanation;
public DataFrameAnalyticsStats(String id, DataFrameAnalyticsState state, @Nullable String failureReason,
@Nullable Integer progressPercent, @Nullable NodeAttributes node,
@Nullable List<PhaseProgress> progress, @Nullable NodeAttributes node,
@Nullable String assignmentExplanation) {
this.id = id;
this.state = state;
this.failureReason = failureReason;
this.progressPercent = progressPercent;
this.progress = progress;
this.node = node;
this.assignmentExplanation = assignmentExplanation;
}
@ -100,8 +102,8 @@ public class DataFrameAnalyticsStats {
return failureReason;
}
public Integer getProgressPercent() {
return progressPercent;
public List<PhaseProgress> getProgress() {
return progress;
}
public NodeAttributes getNode() {
@ -121,14 +123,14 @@ public class DataFrameAnalyticsStats {
return Objects.equals(id, other.id)
&& Objects.equals(state, other.state)
&& Objects.equals(failureReason, other.failureReason)
&& Objects.equals(progressPercent, other.progressPercent)
&& Objects.equals(progress, other.progress)
&& Objects.equals(node, other.node)
&& Objects.equals(assignmentExplanation, other.assignmentExplanation);
}
@Override
public int hashCode() {
return Objects.hash(id, state, failureReason, progressPercent, node, assignmentExplanation);
return Objects.hash(id, state, failureReason, progress, node, assignmentExplanation);
}
@Override
@ -137,7 +139,7 @@ public class DataFrameAnalyticsStats {
.add("id", id)
.add("state", state)
.add("failureReason", failureReason)
.add("progressPercent", progressPercent)
.add("progress", progress)
.add("node", node)
.add("assignmentExplanation", assignmentExplanation)
.toString();

View File

@ -0,0 +1,91 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.inject.internal.ToStringBuilder;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Objects;
/**
* A class that describes a phase and its progress as a percentage
*/
/**
 * Describes a single phase of a data frame analytics task together with
 * its completeness expressed as a percentage.
 */
public class PhaseProgress implements ToXContentObject {

    static final ParseField PHASE = new ParseField("phase");
    static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent");

    // Lenient parser: unknown fields are ignored so newer servers stay readable.
    public static final ConstructingObjectParser<PhaseProgress, Void> PARSER = new ConstructingObjectParser<>("phase_progress",
        true, args -> new PhaseProgress((String) args[0], (int) args[1]));

    static {
        PARSER.declareString(ConstructingObjectParser.constructorArg(), PHASE);
        PARSER.declareInt(ConstructingObjectParser.constructorArg(), PROGRESS_PERCENT);
    }

    // The name of the phase; never null.
    private final String phase;
    // How complete the phase is, as a percentage.
    private final int progressPercent;

    public PhaseProgress(String phase, int progressPercent) {
        this.phase = Objects.requireNonNull(phase);
        this.progressPercent = progressPercent;
    }

    public String getPhase() {
        return phase;
    }

    public int getProgressPercent() {
        return progressPercent;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        return builder.startObject()
            .field(PHASE.getPreferredName(), phase)
            .field(PROGRESS_PERCENT.getPreferredName(), progressPercent)
            .endObject();
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        PhaseProgress other = (PhaseProgress) o;
        return progressPercent == other.progressPercent && Objects.equals(phase, other.phase);
    }

    @Override
    public int hashCode() {
        return Objects.hash(phase, progressPercent);
    }

    @Override
    public String toString() {
        return new ToStringBuilder(getClass())
            .add(PHASE.getPreferredName(), phase)
            .add(PROGRESS_PERCENT.getPreferredName(), progressPercent)
            .toString();
    }
}

View File

@ -123,6 +123,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
@ -1405,11 +1406,17 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
assertThat(stats.getId(), equalTo(configId));
assertThat(stats.getState(), equalTo(DataFrameAnalyticsState.STOPPED));
assertNull(stats.getFailureReason());
assertNull(stats.getProgressPercent());
assertNull(stats.getNode());
assertNull(stats.getAssignmentExplanation());
assertThat(statsResponse.getNodeFailures(), hasSize(0));
assertThat(statsResponse.getTaskFailures(), hasSize(0));
List<PhaseProgress> progress = stats.getProgress();
assertThat(progress, is(notNullValue()));
assertThat(progress.size(), equalTo(4));
assertThat(progress.get(0), equalTo(new PhaseProgress("reindexing", 0)));
assertThat(progress.get(1), equalTo(new PhaseProgress("loading_data", 0)));
assertThat(progress.get(2), equalTo(new PhaseProgress("analyzing", 0)));
assertThat(progress.get(3), equalTo(new PhaseProgress("writing_results", 0)));
}
public void testStartDataFrameAnalyticsConfig() throws Exception {

View File

@ -24,6 +24,8 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.test.ESTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester;
@ -44,11 +46,20 @@ public class DataFrameAnalyticsStatsTests extends ESTestCase {
randomAlphaOfLengthBetween(1, 10),
randomFrom(DataFrameAnalyticsState.values()),
randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null : randomIntBetween(0, 100),
randomBoolean() ? null : createRandomProgress(),
randomBoolean() ? null : NodeAttributesTests.createRandom(),
randomBoolean() ? null : randomAlphaOfLengthBetween(1, 20));
}
private static List<PhaseProgress> createRandomProgress() {
int progressPhaseCount = randomIntBetween(3, 7);
List<PhaseProgress> progress = new ArrayList<>(progressPhaseCount);
for (int i = 0; i < progressPhaseCount; i++) {
progress.add(new PhaseProgress(randomAlphaOfLength(20), randomIntBetween(0, 100)));
}
return progress;
}
public static void toXContent(DataFrameAnalyticsStats stats, XContentBuilder builder) throws IOException {
builder.startObject();
builder.field(DataFrameAnalyticsStats.ID.getPreferredName(), stats.getId());
@ -56,8 +67,8 @@ public class DataFrameAnalyticsStatsTests extends ESTestCase {
if (stats.getFailureReason() != null) {
builder.field(DataFrameAnalyticsStats.FAILURE_REASON.getPreferredName(), stats.getFailureReason());
}
if (stats.getProgressPercent() != null) {
builder.field(DataFrameAnalyticsStats.PROGRESS_PERCENT.getPreferredName(), stats.getProgressPercent());
if (stats.getProgress() != null) {
builder.field(DataFrameAnalyticsStats.PROGRESS.getPreferredName(), stats.getProgress());
}
if (stats.getNode() != null) {
builder.field(DataFrameAnalyticsStats.NODE.getPreferredName(), stats.getNode());

View File

@ -0,0 +1,46 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.dataframe;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
/**
 * XContent round-trip tests for the client-side {@code PhaseProgress}.
 */
public class PhaseProgressTests extends AbstractXContentTestCase<PhaseProgress> {

    public static PhaseProgress createRandom() {
        String phase = randomAlphaOfLength(20);
        int percent = randomIntBetween(0, 100);
        return new PhaseProgress(phase, percent);
    }

    @Override
    protected PhaseProgress createTestInstance() {
        return createRandom();
    }

    @Override
    protected PhaseProgress doParseInstance(XContentParser parser) throws IOException {
        return PhaseProgress.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // The parser is lenient, so round-tripping with injected unknown fields must succeed.
        return true;
    }
}

View File

@ -99,7 +99,25 @@ The API returns the following results:
"data_frame_analytics": [
{
"id": "loganalytics",
"state": "stopped"
"state": "stopped",
"progress": [
{
"phase": "reindexing",
"progress_percent": 0
},
{
"phase": "loading_data",
"progress_percent": 0
},
{
"phase": "analyzing",
"progress_percent": 0
},
{
"phase": "writing_results",
"progress_percent": 0
}
]
}
]
}

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequestBuilder;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
@ -28,8 +29,10 @@ import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@ -154,19 +157,23 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
private final DataFrameAnalyticsState state;
@Nullable
private final String failureReason;
@Nullable
private final Integer progressPercentage;
/**
* The progress is described as a list of each phase and its completeness percentage.
*/
private final List<PhaseProgress> progress;
@Nullable
private final DiscoveryNode node;
@Nullable
private final String assignmentExplanation;
public Stats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, @Nullable Integer progressPercentage,
public Stats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, List<PhaseProgress> progress,
@Nullable DiscoveryNode node, @Nullable String assignmentExplanation) {
this.id = Objects.requireNonNull(id);
this.state = Objects.requireNonNull(state);
this.failureReason = failureReason;
this.progressPercentage = progressPercentage;
this.progress = Objects.requireNonNull(progress);
this.node = node;
this.assignmentExplanation = assignmentExplanation;
}
@ -175,11 +182,47 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
id = in.readString();
state = DataFrameAnalyticsState.fromStream(in);
failureReason = in.readOptionalString();
progressPercentage = in.readOptionalInt();
if (in.getVersion().before(Version.V_7_4_0)) {
progress = readProgressFromLegacy(state, in);
} else {
progress = in.readList(PhaseProgress::new);
}
node = in.readOptionalWriteable(DiscoveryNode::new);
assignmentExplanation = in.readOptionalString();
}
/**
 * Reconstructs the per-phase progress list from the single legacy
 * "progress_percent" integer sent by pre-7.4 nodes, using the task
 * {@code state} to decide which phase that integer referred to.
 */
private static List<PhaseProgress> readProgressFromLegacy(DataFrameAnalyticsState state, StreamInput in) throws IOException {
    Integer legacyProgressPercent = in.readOptionalInt();
    if (legacyProgressPercent == null) {
        // The legacy node reported no progress at all.
        return Collections.emptyList();
    }

    int reindexingProgress = 0;
    int loadingDataProgress = 0;
    int analyzingProgress = 0;
    switch (state) {
        case ANALYZING:
            // If the task is analyzing, the earlier phases must have completed.
            reindexingProgress = 100;
            loadingDataProgress = 100;
            analyzingProgress = legacyProgressPercent;
            break;
        case REINDEXING:
            reindexingProgress = legacyProgressPercent;
            break;
        case STARTED:
        case STOPPED:
        case STOPPING:
        default:
            // NOTE(review): this branch returns null while the no-percentage case above
            // returns emptyList(); callers must therefore handle both. Confirm the
            // null return is intentional rather than another emptyList().
            return null;
    }

    return Arrays.asList(
        new PhaseProgress("reindexing", reindexingProgress),
        new PhaseProgress("loading_data", loadingDataProgress),
        new PhaseProgress("analyzing", analyzingProgress),
        new PhaseProgress("writing_results", 0));
}
/**
 * Returns the id of the data frame analytics task these stats describe.
 */
public String getId() {
    return id;
}
@ -188,6 +231,10 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
return state;
}
/**
 * Returns the per-phase progress of the task; each entry pairs a phase
 * name with its completeness percentage.
 */
public List<PhaseProgress> getProgress() {
    return progress;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
// TODO: Have callers wrap the content with an object as they choose rather than forcing it upon them
@ -204,8 +251,8 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
if (failureReason != null) {
builder.field("failure_reason", failureReason);
}
if (progressPercentage != null) {
builder.field("progress_percent", progressPercentage);
if (progress != null) {
builder.field("progress", progress);
}
if (node != null) {
builder.startObject("node");
@ -232,14 +279,43 @@ public class GetDataFrameAnalyticsStatsAction extends ActionType<GetDataFrameAna
out.writeString(id);
state.writeTo(out);
out.writeOptionalString(failureReason);
out.writeOptionalInt(progressPercentage);
if (out.getVersion().before(Version.V_7_4_0)) {
writeProgressToLegacy(out);
} else {
out.writeList(progress);
}
out.writeOptionalWriteable(node);
out.writeOptionalString(assignmentExplanation);
}
/**
 * Serializes progress for pre-7.4 nodes, which understand only a single
 * optional "progress_percent" integer interpreted relative to the task state.
 */
private void writeProgressToLegacy(StreamOutput out) throws IOException {
    // Map the current state to the phase whose percentage a legacy node expects.
    String legacyPhase = null;
    switch (state) {
        case ANALYZING:
            legacyPhase = "analyzing";
            break;
        case REINDEXING:
            legacyPhase = "reindexing";
            break;
        case STARTED:
        case STOPPED:
        case STOPPING:
        default:
            break;
    }

    // Pick the percentage of the matching phase; stays null when no phase applies
    // (writeOptionalInt then transmits the absence marker).
    Integer percentToSend = null;
    for (PhaseProgress candidate : progress) {
        if (candidate.getPhase().equals(legacyPhase)) {
            percentToSend = candidate.getProgressPercent();
        }
    }
    out.writeOptionalInt(percentToSend);
}
@Override
public int hashCode() {
return Objects.hash(id, state, failureReason, progressPercentage, node, assignmentExplanation);
return Objects.hash(id, state, failureReason, progress, node, assignmentExplanation);
}
@Override

View File

@ -0,0 +1,83 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.utils;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Objects;
/**
* A class that describes a phase and its progress as a percentage
*/
/**
 * Describes a single phase of a data frame analytics task together with
 * its completeness expressed as a percentage.
 */
public class PhaseProgress implements ToXContentObject, Writeable {

    public static final ParseField PHASE = new ParseField("phase");
    public static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent");

    // Lenient parser: unknown fields are ignored.
    public static final ConstructingObjectParser<PhaseProgress, Void> PARSER = new ConstructingObjectParser<>("phase_progress",
        true, args -> new PhaseProgress((String) args[0], (int) args[1]));

    static {
        PARSER.declareString(ConstructingObjectParser.constructorArg(), PHASE);
        PARSER.declareInt(ConstructingObjectParser.constructorArg(), PROGRESS_PERCENT);
    }

    // The name of the phase; never null.
    private final String phase;
    // How complete the phase is, as a percentage.
    private final int progressPercent;

    public PhaseProgress(String phase, int progressPercent) {
        this.phase = Objects.requireNonNull(phase);
        this.progressPercent = progressPercent;
    }

    public PhaseProgress(StreamInput in) throws IOException {
        // Read order mirrors writeTo: phase string first, then the percent as a vint.
        this(in.readString(), in.readVInt());
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(phase);
        out.writeVInt(progressPercent);
    }

    public String getPhase() {
        return phase;
    }

    public int getProgressPercent() {
        return progressPercent;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        return builder.startObject()
            .field(PHASE.getPreferredName(), phase)
            .field(PROGRESS_PERCENT.getPreferredName(), progressPercent)
            .endObject();
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        PhaseProgress other = (PhaseProgress) o;
        return progressPercent == other.progressPercent && Objects.equals(phase, other.phase);
    }

    @Override
    public int hashCode() {
        return Objects.hash(phase, progressPercent);
    }
}

View File

@ -11,9 +11,11 @@ import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Response;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;
public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> {
@ -22,10 +24,13 @@ public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireS
int listSize = randomInt(10);
List<Response.Stats> analytics = new ArrayList<>(listSize);
for (int j = 0; j < listSize; j++) {
Integer progressPercentage = randomBoolean() ? null : randomIntBetween(0, 100);
String failureReason = randomBoolean() ? null : randomAlphaOfLength(10);
int progressSize = randomIntBetween(2, 5);
List<PhaseProgress> progress = new ArrayList<>(progressSize);
IntStream.of(progressSize).forEach(progressIndex -> progress.add(
new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100))));
Response.Stats stats = new Response.Stats(DataFrameAnalyticsConfigTests.randomValidId(),
randomFrom(DataFrameAnalyticsState.values()), failureReason, progressPercentage, null, randomAlphaOfLength(20));
randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, null, randomAlphaOfLength(20));
analytics.add(stats);
}
return new Response(new QueryPage<>(analytics, analytics.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));

View File

@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.utils;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import java.io.IOException;
/**
 * Wire- and XContent-serialization round-trip tests for {@code PhaseProgress}.
 */
public class PhaseProgressTests extends AbstractSerializingTestCase<PhaseProgress> {

    public static PhaseProgress createRandom() {
        String phase = randomAlphaOfLength(10);
        int percent = randomIntBetween(0, 100);
        return new PhaseProgress(phase, percent);
    }

    @Override
    protected PhaseProgress createTestInstance() {
        return createRandom();
    }

    @Override
    protected PhaseProgress doParseInstance(XContentParser parser) throws IOException {
        return PhaseProgress.PARSER.apply(parser, null);
    }

    @Override
    protected Writeable.Reader<PhaseProgress> instanceReader() {
        return PhaseProgress::new;
    }
}

View File

@ -5,11 +5,11 @@
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
@ -22,14 +22,16 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
/**
* Base class of ML integration tests that use a native data_frame_analytics process
@ -46,7 +48,8 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
private void cleanUpAnalytics() {
for (DataFrameAnalyticsConfig config : analytics) {
try {
deleteAnalytics(config.getId());
assertThat(deleteAnalytics(config.getId()).isAcknowledged(), is(true));
assertThat(searchStoredProgress(config.getId()).getHits().getTotalHits().value, equalTo(0L));
} catch (Exception e) {
// ignore
}
@ -100,10 +103,6 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
return response.getResponse().results();
}
protected static String createJsonRecord(Map<String, Object> keyValueMap) throws IOException {
return Strings.toString(JsonXContent.contentBuilder().map(keyValueMap)) + "\n";
}
protected static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String[] sourceIndex, String destIndex,
@Nullable String resultsField) {
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();
@ -121,6 +120,28 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
assertThat(stats.get(0).getState(), equalTo(state));
}
protected void assertProgress(String id, int reindexing, int loadingData, int analyzing, int writingResults) {
List<GetDataFrameAnalyticsStatsAction.Response.Stats> stats = getAnalyticsStats(id);
List<PhaseProgress> progress = stats.get(0).getProgress();
assertThat(stats.size(), equalTo(1));
assertThat(stats.get(0).getId(), equalTo(id));
assertThat(progress.size(), equalTo(4));
assertThat(progress.get(0).getPhase(), equalTo("reindexing"));
assertThat(progress.get(1).getPhase(), equalTo("loading_data"));
assertThat(progress.get(2).getPhase(), equalTo("analyzing"));
assertThat(progress.get(3).getPhase(), equalTo("writing_results"));
assertThat(progress.get(0).getProgressPercent(), equalTo(reindexing));
assertThat(progress.get(1).getProgressPercent(), equalTo(loadingData));
assertThat(progress.get(2).getProgressPercent(), equalTo(analyzing));
assertThat(progress.get(3).getProgressPercent(), equalTo(writingResults));
}
protected SearchResponse searchStoredProgress(String id) {
return client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
.setQuery(QueryBuilders.idsQuery().addIds(TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask.progressDocId(id)))
.get();
}
protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex,
@Nullable String resultsField, String dependentVariable) {
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();

View File

@ -73,6 +73,7 @@ public class OutlierDetectionWithMissingFieldsIT extends MlNativeDataFrameAnalyt
putAnalytics(config);
assertState(id, DataFrameAnalyticsState.STOPPED);
assertProgress(id, 0, 0, 0, 0);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
@ -99,5 +100,8 @@ public class OutlierDetectionWithMissingFieldsIT extends MlNativeDataFrameAnalyt
assertThat(destDoc.containsKey("ml"), is(false));
}
}
assertProgress(id, 100, 100, 100, 100);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
}
}

View File

@ -78,6 +78,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
putAnalytics(config);
assertState(id, DataFrameAnalyticsState.STOPPED);
assertProgress(id, 0, 0, 0, 0);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
@ -113,6 +114,9 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
}
}
assertThat(scoreOfOutlier, is(greaterThan(scoreOfNonOutlier)));
assertProgress(id, 100, 100, 100, 100);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
}
public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception {
@ -143,6 +147,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
putAnalytics(config);
assertState(id, DataFrameAnalyticsState.STOPPED);
assertProgress(id, 0, 0, 0, 0);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
@ -156,6 +161,9 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
.setTrackTotalHits(true)
.setQuery(QueryBuilders.existsQuery("custom_ml.outlier_score")).get();
assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount));
assertProgress(id, 100, 100, 100, 100);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
}
public void testOutlierDetectionWithMoreFieldsThanDocValueFieldLimit() throws Exception {
@ -201,6 +209,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
putAnalytics(config);
assertState(id, DataFrameAnalyticsState.STOPPED);
assertProgress(id, 0, 0, 0, 0);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
@ -224,6 +233,9 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
double outlierScore = (double) resultsObject.get("outlier_score");
assertThat(outlierScore, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
}
assertProgress(id, 100, 100, 100, 100);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
}
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/43960")
@ -312,6 +324,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
putAnalytics(config);
assertState(id, DataFrameAnalyticsState.STOPPED);
assertProgress(id, 0, 0, 0, 0);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
@ -325,6 +338,9 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
.setTrackTotalHits(true)
.setQuery(QueryBuilders.existsQuery("ml.outlier_score")).get();
assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) bulkRequestBuilder.numberOfActions()));
assertProgress(id, 100, 100, 100, 100);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
}
public void testOutlierDetectionWithPreExistingDestIndex() throws Exception {
@ -358,6 +374,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
putAnalytics(config);
assertState(id, DataFrameAnalyticsState.STOPPED);
assertProgress(id, 0, 0, 0, 0);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
@ -371,6 +388,9 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
.setTrackTotalHits(true)
.setQuery(QueryBuilders.existsQuery("ml.outlier_score")).get();
assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) bulkRequestBuilder.numberOfActions()));
assertProgress(id, 100, 100, 100, 100);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
}
public void testRegressionWithNumericFeatureAndFewDocuments() throws Exception {
@ -406,6 +426,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
putAnalytics(config);
assertState(id, DataFrameAnalyticsState.STOPPED);
assertProgress(id, 0, 0, 0, 0);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
@ -438,6 +459,9 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
}
}
assertThat(resultsWithPrediction, greaterThan(0));
assertProgress(id, 100, 100, 100, 100);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
}
public void testModelMemoryLimitLowerThanEstimatedMemoryUsage() {

View File

@ -5,15 +5,20 @@
*/
package org.elasticsearch.xpack.ml.action;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.DocWriteResponse;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.delete.DeleteAction;
import org.elasticsearch.action.delete.DeleteRequest;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.ParentTaskAssigningClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
@ -21,7 +26,14 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.reindex.AbstractBulkByScrollRequest;
import org.elasticsearch.index.reindex.BulkByScrollResponse;
import org.elasticsearch.index.reindex.DeleteByQueryAction;
import org.elasticsearch.index.reindex.DeleteByQueryRequest;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.MlTasks;
@ -30,7 +42,9 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
import org.elasticsearch.xpack.ml.utils.MlIndicesUtils;
import java.io.IOException;
@ -45,18 +59,22 @@ import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
public class TransportDeleteDataFrameAnalyticsAction
extends TransportMasterNodeAction<DeleteDataFrameAnalyticsAction.Request, AcknowledgedResponse> {
private static final Logger LOGGER = LogManager.getLogger(TransportDeleteDataFrameAnalyticsAction.class);
private final Client client;
private final MlMemoryTracker memoryTracker;
private final DataFrameAnalyticsConfigProvider configProvider;
@Inject
public TransportDeleteDataFrameAnalyticsAction(TransportService transportService, ClusterService clusterService,
ThreadPool threadPool, ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver, Client client,
MlMemoryTracker memoryTracker) {
MlMemoryTracker memoryTracker, DataFrameAnalyticsConfigProvider configProvider) {
super(DeleteDataFrameAnalyticsAction.NAME, transportService, clusterService, threadPool, actionFilters,
DeleteDataFrameAnalyticsAction.Request::new, indexNameExpressionResolver);
this.client = client;
this.memoryTracker = memoryTracker;
this.configProvider = configProvider;
}
@Override
@ -72,6 +90,12 @@ public class TransportDeleteDataFrameAnalyticsAction
@Override
protected void masterOperation(DeleteDataFrameAnalyticsAction.Request request, ClusterState state,
ActionListener<AcknowledgedResponse> listener) {
throw new UnsupportedOperationException("The task parameter is required");
}
@Override
protected void masterOperation(Task task, DeleteDataFrameAnalyticsAction.Request request, ClusterState state,
ActionListener<AcknowledgedResponse> listener) {
String id = request.getId();
PersistentTasksCustomMetaData tasks = state.getMetaData().custom(PersistentTasksCustomMetaData.TYPE);
DataFrameAnalyticsState taskState = MlTasks.getDataFrameAnalyticsState(id, tasks);
@ -81,25 +105,70 @@ public class TransportDeleteDataFrameAnalyticsAction
return;
}
TaskId taskId = new TaskId(clusterService.localNode().getId(), task.getId());
ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, taskId);
// We clean up the memory tracker on delete because there is no stop; the task stops by itself
memoryTracker.removeDataFrameAnalyticsJob(id);
// Step 2. Delete the config
ActionListener<BulkByScrollResponse> deleteStateHandler = ActionListener.wrap(
bulkByScrollResponse -> {
if (bulkByScrollResponse.isTimedOut()) {
LOGGER.warn("[{}] DeleteByQuery for state timed out", id);
}
if (bulkByScrollResponse.getBulkFailures().isEmpty() == false) {
LOGGER.warn("[{}] {} failures and {} conflicts encountered while runnint DeleteByQuery for state", id,
bulkByScrollResponse.getBulkFailures().size(), bulkByScrollResponse.getVersionConflicts());
for (BulkItemResponse.Failure failure : bulkByScrollResponse.getBulkFailures()) {
LOGGER.warn("[{}] DBQ failure: {}", id, failure);
}
}
deleteConfig(parentTaskClient, id, listener);
},
listener::onFailure
);
// Step 1. Delete state
ActionListener<DataFrameAnalyticsConfig> configListener = ActionListener.wrap(
config -> deleteState(parentTaskClient, id, deleteStateHandler),
listener::onFailure
);
// Step 1. Get the config to check if it exists
configProvider.get(id, configListener);
}
private void deleteConfig(ParentTaskAssigningClient parentTaskClient, String id, ActionListener<AcknowledgedResponse> listener) {
DeleteRequest deleteRequest = new DeleteRequest(AnomalyDetectorsIndex.configIndexName());
deleteRequest.id(DataFrameAnalyticsConfig.documentId(id));
deleteRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
executeAsyncWithOrigin(client, ML_ORIGIN, DeleteAction.INSTANCE, deleteRequest, ActionListener.wrap(
executeAsyncWithOrigin(parentTaskClient, ML_ORIGIN, DeleteAction.INSTANCE, deleteRequest, ActionListener.wrap(
deleteResponse -> {
if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) {
listener.onFailure(ExceptionsHelper.missingDataFrameAnalytics(id));
return;
}
assert deleteResponse.getResult() == DocWriteResponse.Result.DELETED;
LOGGER.info("[{}] Deleted", id);
listener.onResponse(new AcknowledgedResponse(true));
},
listener::onFailure
));
}
/**
 * Deletes the persisted state documents (currently the stored progress doc) of the given
 * analytics job via delete-by-query against the job state index pattern.
 * The listener receives the raw {@link BulkByScrollResponse} so the caller can inspect failures.
 */
private void deleteState(ParentTaskAssigningClient parentTaskClient, String analyticsId,
                         ActionListener<BulkByScrollResponse> listener) {
    DeleteByQueryRequest deleteByQueryRequest = new DeleteByQueryRequest(AnomalyDetectorsIndex.jobStateIndexPattern());
    // Target exactly the progress document for this job
    deleteByQueryRequest.setQuery(QueryBuilders.idsQuery().addIds(
        TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask.progressDocId(analyticsId)));
    // Missing state indices are fine; there is simply nothing to delete
    deleteByQueryRequest.setIndicesOptions(MlIndicesUtils.addIgnoreUnavailable(IndicesOptions.lenientExpandOpen()));
    deleteByQueryRequest.setAbortOnVersionConflict(false);
    deleteByQueryRequest.setSlices(AbstractBulkByScrollRequest.AUTO_SLICES);
    deleteByQueryRequest.setRefresh(true);
    executeAsyncWithOrigin(parentTaskClient, ML_ORIGIN, DeleteByQueryAction.INSTANCE, deleteByQueryRequest, listener);
}
@Override
protected ClusterBlockException checkBlock(DeleteDataFrameAnalyticsAction.Request request, ClusterState state) {
return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);

View File

@ -7,24 +7,32 @@ package org.elasticsearch.xpack.ml.action;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ResourceNotFoundException;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskRequest;
import org.elasticsearch.action.search.MultiSearchAction;
import org.elasticsearch.action.search.MultiSearchRequest;
import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.index.reindex.BulkByScrollTask;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskResult;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.action.util.QueryPage;
@ -35,9 +43,14 @@ import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.R
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager;
import org.elasticsearch.xpack.ml.dataframe.StoredProgress;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
@ -55,16 +68,14 @@ public class TransportGetDataFrameAnalyticsStatsAction
private static final Logger LOGGER = LogManager.getLogger(TransportGetDataFrameAnalyticsStatsAction.class);
private final Client client;
private final AnalyticsProcessManager analyticsProcessManager;
@Inject
public TransportGetDataFrameAnalyticsStatsAction(TransportService transportService, ClusterService clusterService, Client client,
ActionFilters actionFilters, AnalyticsProcessManager analyticsProcessManager) {
ActionFilters actionFilters) {
super(GetDataFrameAnalyticsStatsAction.NAME, clusterService, transportService, actionFilters,
GetDataFrameAnalyticsStatsAction.Request::new, GetDataFrameAnalyticsStatsAction.Response::new,
in -> new QueryPage<>(in, GetDataFrameAnalyticsStatsAction.Response.Stats::new), ThreadPool.Names.MANAGEMENT);
this.client = client;
this.analyticsProcessManager = analyticsProcessManager;
}
@Override
@ -86,7 +97,7 @@ public class TransportGetDataFrameAnalyticsStatsAction
ActionListener<QueryPage<Stats>> listener) {
LOGGER.debug("Get stats for running task [{}]", task.getParams().getId());
ActionListener<Integer> progressListener = ActionListener.wrap(
ActionListener<List<PhaseProgress>> progressListener = ActionListener.wrap(
progress -> {
Stats stats = buildStats(task.getParams().getId(), progress);
listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1,
@ -94,38 +105,14 @@ public class TransportGetDataFrameAnalyticsStatsAction
}, listener::onFailure
);
ClusterState clusterState = clusterService.state();
PersistentTasksCustomMetaData tasks = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE);
DataFrameAnalyticsState analyticsState = MlTasks.getDataFrameAnalyticsState(task.getParams().getId(), tasks);
// For a running task we report the progress associated with its current state
if (analyticsState == DataFrameAnalyticsState.REINDEXING) {
getReindexTaskProgress(task, progressListener);
} else {
progressListener.onResponse(analyticsProcessManager.getProgressPercent(task.getAllocationId()));
}
}
private void getReindexTaskProgress(DataFrameAnalyticsTask task, ActionListener<Integer> listener) {
TaskId reindexTaskId = new TaskId(clusterService.localNode().getId(), task.getReindexingTaskId());
GetTaskRequest getTaskRequest = new GetTaskRequest();
getTaskRequest.setTaskId(reindexTaskId);
client.admin().cluster().getTask(getTaskRequest, ActionListener.wrap(
taskResponse -> {
TaskResult taskResult = taskResponse.getTask();
BulkByScrollTask.Status taskStatus = (BulkByScrollTask.Status) taskResult.getTask().getStatus();
int progress = taskStatus.getTotal() == 0 ? 100 : (int) (taskStatus.getCreated() * 100.0 / taskStatus.getTotal());
listener.onResponse(progress);
ActionListener<Void> reindexingProgressListener = ActionListener.wrap(
aVoid -> {
progressListener.onResponse(task.getProgressTracker().report());
},
error -> {
if (error instanceof ResourceNotFoundException) {
// The task has either not started yet or has finished, thus it is better to respond null and not show progress at all
listener.onResponse(null);
} else {
listener.onFailure(error);
}
}
));
listener::onFailure
);
task.updateReindexTaskProgress(reindexingProgressListener);
}
@Override
@ -166,12 +153,27 @@ public class TransportGetDataFrameAnalyticsStatsAction
void gatherStatsForStoppedTasks(List<String> expandedIds, GetDataFrameAnalyticsStatsAction.Response runningTasksResponse,
ActionListener<GetDataFrameAnalyticsStatsAction.Response> listener) {
List<String> stoppedTasksIds = determineStoppedTasksIds(expandedIds, runningTasksResponse.getResponse().results());
List<Stats> stoppedTasksStats = stoppedTasksIds.stream().map(this::buildStatsForStoppedTask).collect(Collectors.toList());
List<Stats> allTasksStats = new ArrayList<>(runningTasksResponse.getResponse().results());
allTasksStats.addAll(stoppedTasksStats);
Collections.sort(allTasksStats, Comparator.comparing(Stats::getId));
listener.onResponse(new GetDataFrameAnalyticsStatsAction.Response(new QueryPage<>(
allTasksStats, allTasksStats.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD)));
if (stoppedTasksIds.isEmpty()) {
listener.onResponse(runningTasksResponse);
return;
}
searchStoredProgresses(stoppedTasksIds, ActionListener.wrap(
storedProgresses -> {
List<Stats> stoppedStats = new ArrayList<>(stoppedTasksIds.size());
for (int i = 0; i < stoppedTasksIds.size(); i++) {
String configId = stoppedTasksIds.get(i);
StoredProgress storedProgress = storedProgresses.get(i);
stoppedStats.add(buildStats(configId, storedProgress.get()));
}
List<Stats> allTasksStats = new ArrayList<>(runningTasksResponse.getResponse().results());
allTasksStats.addAll(stoppedStats);
Collections.sort(allTasksStats, Comparator.comparing(Stats::getId));
listener.onResponse(new GetDataFrameAnalyticsStatsAction.Response(new QueryPage<>(
allTasksStats, allTasksStats.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD)));
},
listener::onFailure
));
}
static List<String> determineStoppedTasksIds(List<String> expandedIds, List<Stats> runningTasksStats) {
@ -179,11 +181,52 @@ public class TransportGetDataFrameAnalyticsStatsAction
return expandedIds.stream().filter(id -> startedTasksIds.contains(id) == false).collect(Collectors.toList());
}
private GetDataFrameAnalyticsStatsAction.Response.Stats buildStatsForStoppedTask(String concreteAnalyticsId) {
return buildStats(concreteAnalyticsId, null);
/**
 * Looks up the persisted {@link StoredProgress} document for each given config id with a
 * single multi-search (one search per id, added in the order of {@code configIds}).
 * The listener receives one entry per id in the same order; ids without a stored document
 * get a zeroed progress (a fresh {@link DataFrameAnalyticsTask.ProgressTracker} report).
 */
private void searchStoredProgresses(List<String> configIds, ActionListener<List<StoredProgress>> listener) {
    MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
    for (String configId : configIds) {
        SearchRequest searchRequest = new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern());
        searchRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
        // At most one progress doc exists per job, so one hit is enough
        searchRequest.source().size(1);
        searchRequest.source().query(QueryBuilders.idsQuery().addIds(DataFrameAnalyticsTask.progressDocId(configId)));
        multiSearchRequest.add(searchRequest);
    }
    executeAsyncWithOrigin(client, ML_ORIGIN, MultiSearchAction.INSTANCE, multiSearchRequest, ActionListener.wrap(
        multiSearchResponse -> {
            List<StoredProgress> progresses = new ArrayList<>(configIds.size());
            // Multi-search items come back in request order, so item i corresponds to configIds.get(i)
            for (MultiSearchResponse.Item itemResponse : multiSearchResponse.getResponses()) {
                if (itemResponse.isFailure()) {
                    // Fail fast on the first failed item; the listener is then notified exactly once
                    listener.onFailure(ExceptionsHelper.serverError(itemResponse.getFailureMessage(), itemResponse.getFailure()));
                    return;
                } else {
                    SearchHit[] hits = itemResponse.getResponse().getHits().getHits();
                    if (hits.length == 0) {
                        // No stored progress: report all phases at 0 (the tracker's initial state)
                        progresses.add(new StoredProgress(new DataFrameAnalyticsTask.ProgressTracker().report()));
                    } else {
                        progresses.add(parseStoredProgress(hits[0]));
                    }
                }
            }
            listener.onResponse(progresses);
        },
        e -> listener.onFailure(ExceptionsHelper.serverError("Error searching for stored progresses", e))
    ));
}
private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId, @Nullable Integer progressPercent) {
/**
 * Parses the source of the given search hit into a {@link StoredProgress}.
 * A parse failure is logged and mapped to an empty progress object rather than
 * propagated, so a single corrupt document cannot fail the whole stats request.
 */
private StoredProgress parseStoredProgress(SearchHit hit) {
    BytesReference source = hit.getSourceRef();
    try (InputStream stream = source.streamInput();
         XContentParser parser = XContentFactory.xContent(XContentType.JSON)
             .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, stream)) {
        return StoredProgress.PARSER.apply(parser, null);
    } catch (IOException e) {
        // Fixed log message typo: was "doc with it"
        LOGGER.error(new ParameterizedMessage("failed to parse progress from doc with id [{}]", hit.getId()), e);
        return new StoredProgress(Collections.emptyList());
    }
}
private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId, List<PhaseProgress> progress) {
ClusterState clusterState = clusterService.state();
PersistentTasksCustomMetaData tasks = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE);
PersistentTasksCustomMetaData.PersistentTask<?> analyticsTask = MlTasks.getDataFrameAnalyticsTask(concreteAnalyticsId, tasks);
@ -200,6 +243,6 @@ public class TransportGetDataFrameAnalyticsStatsAction
assignmentExplanation = analyticsTask.getAssignment().getExplanation();
}
return new GetDataFrameAnalyticsStatsAction.Response.Stats(
concreteAnalyticsId, analyticsState, failureReason, progressPercent, node, assignmentExplanation);
concreteAnalyticsId, analyticsState, failureReason, progress, node, assignmentExplanation);
}
}

View File

@ -15,10 +15,14 @@ import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskRequest;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.client.Client;
@ -34,7 +38,10 @@ import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.reindex.BulkByScrollTask;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.persistent.AllocatedPersistentTask;
@ -45,6 +52,7 @@ import org.elasticsearch.persistent.PersistentTasksExecutor;
import org.elasticsearch.persistent.PersistentTasksService;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskResult;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
@ -52,6 +60,7 @@ import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@ -59,10 +68,13 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.core.watcher.watch.Payload;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager;
import org.elasticsearch.xpack.ml.dataframe.MappingsMerger;
import org.elasticsearch.xpack.ml.dataframe.SourceDestValidator;
import org.elasticsearch.xpack.ml.dataframe.StoredProgress;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
import org.elasticsearch.xpack.ml.job.JobNodeSelector;
@ -70,12 +82,16 @@ import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.core.ml.MlTasks.AWAITING_UPGRADE;
import static org.elasticsearch.xpack.ml.MachineLearning.MAX_OPEN_JOBS_PER_NODE;
@ -369,7 +385,9 @@ public class TransportStartDataFrameAnalyticsAction
private final StartDataFrameAnalyticsAction.TaskParams taskParams;
@Nullable
private volatile Long reindexingTaskId;
private volatile boolean isReindexingFinished;
private volatile boolean isStopping;
private final ProgressTracker progressTracker = new ProgressTracker();
public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map<String, String> headers,
Client client, ClusterService clusterService, DataFrameAnalyticsManager analyticsManager,
@ -389,22 +407,50 @@ public class TransportStartDataFrameAnalyticsAction
this.reindexingTaskId = reindexingTaskId;
}
@Nullable
public Long getReindexingTaskId() {
return reindexingTaskId;
/**
 * Marks the reindexing phase as finished. This flag lets progress reporting return 100%
 * for reindexing once the reindex task itself is no longer retrievable.
 */
public void setReindexingFinished() {
    isReindexingFinished = true;
}
/**
 * @return whether {@code stop} has been requested on this task
 */
public boolean isStopping() {
    return isStopping;
}
/**
 * @return the tracker holding the per-phase progress percentages of this task
 */
public ProgressTracker getProgressTracker() {
    return progressTracker;
}
@Override
protected void onCancelled() {
    // Treat cancellation as a stop request with the cancellation reason and a zero timeout
    stop(getReasonCancelled(), TimeValue.ZERO);
}
@Override
public void markAsCompleted() {
    // Persist the final progress before completing so stats can still be served once the task is gone
    persistProgress(super::markAsCompleted);
}
@Override
public void markAsFailed(Exception e) {
    // Persist the progress reached so far, then propagate the failure as usual
    final Runnable propagateFailure = () -> super.markAsFailed(e);
    persistProgress(propagateFailure);
}
/**
 * Stops this task. The reindexing progress is refreshed from the reindex task first,
 * because cancelling that task would otherwise lose its latest progress value.
 *
 * @param reason the reason the task is stopping
 * @param timeout timeout forwarded to the actual stop operation
 */
public void stop(String reason, TimeValue timeout) {
    isStopping = true;
    ActionListener<Void> reindexProgressListener = ActionListener.wrap(
        aVoid -> doStop(reason, timeout),
        e -> {
            LOGGER.error(new ParameterizedMessage("[{}] Error updating reindexing progress", taskParams.getId()), e);
            // We should log the error but it shouldn't stop us from stopping the task
            doStop(reason, timeout);
        }
    );
    // We need to update reindexing progress before we cancel the task
    updateReindexTaskProgress(reindexProgressListener);
}
private void doStop(String reason, TimeValue timeout) {
if (reindexingTaskId != null) {
cancelReindexingTask(reason, timeout);
}
@ -440,10 +486,115 @@ public class TransportStartDataFrameAnalyticsAction
DataFrameAnalyticsTaskState newTaskState = new DataFrameAnalyticsTaskState(state, getAllocationId(), reason);
updatePersistentTaskState(newTaskState, ActionListener.wrap(
updatedTask -> LOGGER.info("[{}] Successfully update task state to [{}]", getParams().getId(), state),
e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}]",
getParams().getId(), state), e)
e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}] with reason [{}]",
getParams().getId(), state, reason), e)
));
}
/**
 * Refreshes {@code progressTracker.reindexingPercent} from the status of the reindex task
 * and then notifies the listener with {@code null}. If the reindex task cannot be found
 * (it has not started yet or has already finished), progress is set to 100 only when
 * reindexing is known to have finished.
 */
public void updateReindexTaskProgress(ActionListener<Void> listener) {
    TaskId reindexTaskId = getReindexTaskId();
    if (reindexTaskId == null) {
        // The task is not present which means either it has not started yet or it finished.
        updateProgressForMissingReindexTask();
        listener.onResponse(null);
        return;
    }
    GetTaskRequest getTaskRequest = new GetTaskRequest();
    getTaskRequest.setTaskId(reindexTaskId);
    client.admin().cluster().getTask(getTaskRequest, ActionListener.wrap(
        taskResponse -> {
            TaskResult taskResult = taskResponse.getTask();
            BulkByScrollTask.Status taskStatus = (BulkByScrollTask.Status) taskResult.getTask().getStatus();
            // Guard against division by zero when the reindex task has nothing to do
            int progress = taskStatus.getTotal() == 0 ? 0 : (int) (taskStatus.getCreated() * 100.0 / taskStatus.getTotal());
            progressTracker.reindexingPercent.set(progress);
            listener.onResponse(null);
        },
        error -> {
            if (error instanceof ResourceNotFoundException) {
                // Same situation as a missing task id: the task either never started or already finished.
                updateProgressForMissingReindexTask();
                listener.onResponse(null);
            } else {
                listener.onFailure(error);
            }
        }
    ));
}

// The reindex task is gone; we only know progress is 100 if we recorded that reindexing finished.
private void updateProgressForMissingReindexTask() {
    if (isReindexingFinished) {
        progressTracker.reindexingPercent.set(100);
    }
}
/**
 * @return the {@link TaskId} of the reindex task, or {@code null} if no reindexing task id
 *         is set (the task either never started reindexing or has already finished it)
 */
@Nullable
private TaskId getReindexTaskId() {
    // Snapshot the volatile field once so the null check and the use observe the same value.
    // This replaces the previous pattern of catching the NPE thrown by unboxing a null Long,
    // which used exceptions for control flow and could mask unrelated NPEs.
    Long taskId = reindexingTaskId;
    if (taskId == null) {
        return null;
    }
    return new TaskId(clusterService.localNode().getId(), taskId);
}
/**
 * Persists the task's current progress as a document in the state index and then runs
 * {@code runnable} (the super call of {@code markAsCompleted}/{@code markAsFailed}).
 * Progress is fetched through the stats action so that exactly the values the stats API
 * reports get persisted. Errors while fetching stats or indexing are logged but never
 * prevent {@code runnable} from running, so task completion is not blocked.
 */
private void persistProgress(Runnable runnable) {
    GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(taskParams.getId());
    executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, ActionListener.wrap(
        statsResponse -> {
            // NOTE(review): assumes the stats response always contains an entry for this id — verify
            GetDataFrameAnalyticsStatsAction.Response.Stats stats = statsResponse.getResponse().results().get(0);
            IndexRequest indexRequest = new IndexRequest(AnomalyDetectorsIndex.jobStateIndexWriteAlias());
            indexRequest.id(progressDocId(taskParams.getId()));
            // Refresh immediately so the document is searchable as soon as the task finishes
            indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            try (XContentBuilder jsonBuilder = JsonXContent.contentBuilder()) {
                new StoredProgress(stats.getProgress()).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
                indexRequest.source(jsonBuilder);
            }
            executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, ActionListener.wrap(
                indexResponse -> {
                    LOGGER.debug("[{}] Successfully indexed progress document", taskParams.getId());
                    runnable.run();
                },
                indexError -> {
                    // Best-effort persistence: log and continue with the completion/failure path
                    LOGGER.error(new ParameterizedMessage(
                        "[{}] cannot persist progress as an error occurred while indexing", taskParams.getId()), indexError);
                    runnable.run();
                }
            ));
        },
        e -> {
            // Best-effort persistence: log and continue with the completion/failure path
            LOGGER.error(new ParameterizedMessage(
                "[{}] cannot persist progress as an error occurred while retrieving stats", taskParams.getId()), e);
            runnable.run();
        }
    ));
}
/**
 * Builds the id of the document under which progress for the given analytics id is stored,
 * e.g. {@code data_frame_analytics-<id>-progress}.
 */
public static String progressDocId(String id) {
    final String prefix = "data_frame_analytics-";
    final String suffix = "-progress";
    return prefix + id + suffix;
}
/**
 * Holds per-phase progress percentages for a data frame analytics task.
 * Each percentage is an {@link AtomicInteger} so individual phases can be
 * updated concurrently while stats are being reported.
 */
public static class ProgressTracker {

    public static final String REINDEXING = "reindexing";
    public static final String LOADING_DATA = "loading_data";
    public static final String ANALYZING = "analyzing";
    public static final String WRITING_RESULTS = "writing_results";

    public final AtomicInteger reindexingPercent = new AtomicInteger(0);
    public final AtomicInteger loadingDataPercent = new AtomicInteger(0);
    public final AtomicInteger analyzingPercent = new AtomicInteger(0);
    public final AtomicInteger writingResultsPercent = new AtomicInteger(0);

    /**
     * @return a snapshot of all phases, always in the order
     *         reindexing, loading_data, analyzing, writing_results
     */
    public List<PhaseProgress> report() {
        PhaseProgress reindexing = new PhaseProgress(REINDEXING, reindexingPercent.get());
        PhaseProgress loadingData = new PhaseProgress(LOADING_DATA, loadingDataPercent.get());
        PhaseProgress analyzing = new PhaseProgress(ANALYZING, analyzingPercent.get());
        PhaseProgress writingResults = new PhaseProgress(WRITING_RESULTS, writingResultsPercent.get());
        return Arrays.asList(reindexing, loadingData, analyzing, writingResults);
    }
}
}
static List<String> verifyIndicesPrimaryShardsAreActive(ClusterState clusterState, String... indexNames) {

View File

@ -126,6 +126,10 @@ public class DataFrameAnalyticsManager {
// Reindexing is complete; start analytics
ActionListener<RefreshResponse> refreshListener = ActionListener.wrap(
refreshResponse -> {
if (task.isStopping()) {
LOGGER.debug("[{}] Stopping before starting analytics process", config.getId());
return;
}
task.setReindexingTaskId(null);
startAnalytics(task, config, false);
},
@ -134,12 +138,18 @@ public class DataFrameAnalyticsManager {
// Refresh to ensure copied index is fully searchable
ActionListener<BulkByScrollResponse> reindexCompletedListener = ActionListener.wrap(
bulkResponse ->
bulkResponse -> {
if (task.isStopping()) {
LOGGER.debug("[{}] Stopping before refreshing destination index", config.getId());
return;
}
task.setReindexingFinished();
ClientHelper.executeAsyncWithOrigin(client,
ClientHelper.ML_ORIGIN,
RefreshAction.INSTANCE,
new RefreshRequest(config.getDest().getIndex()),
refreshListener),
refreshListener);
},
error -> task.updateState(DataFrameAnalyticsState.FAILED, error.getMessage())
);
@ -187,6 +197,9 @@ public class DataFrameAnalyticsManager {
}
private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, boolean isTaskRestarting) {
// Ensure we mark reindexing is finished for the case we are recovering a task that had finished reindexing
task.setReindexingFinished();
// Update state to ANALYZING and start process
ActionListener<DataFrameDataExtractorFactory> dataExtractorFactoryListener = ActionListener.wrap(
dataExtractorFactory -> {

View File

@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
/**
 * Wraps the list of per-phase progress entries for a data frame analytics task
 * so it can be persisted to (and parsed back from) the state index when the
 * task stops.
 */
public class StoredProgress implements ToXContentObject {

    private static final ParseField PROGRESS = new ParseField("progress");

    @SuppressWarnings("unchecked")
    public static final ConstructingObjectParser<StoredProgress, Void> PARSER = new ConstructingObjectParser<>(
        PROGRESS.getPreferredName(), true, a -> new StoredProgress((List<PhaseProgress>) a[0]));

    static {
        PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), PhaseProgress.PARSER, PROGRESS);
    }

    // One entry per analytics phase (e.g. reindexing, loading_data, analyzing, writing_results)
    private final List<PhaseProgress> phaseProgresses;

    public StoredProgress(List<PhaseProgress> progress) {
        phaseProgresses = Objects.requireNonNull(progress);
    }

    /**
     * @return the wrapped list of phase progress entries
     */
    public List<PhaseProgress> get() {
        return phaseProgresses;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        return builder.startObject()
            .field(PROGRESS.getPreferredName(), phaseProgresses)
            .endObject();
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || getClass().equals(o.getClass()) == false) {
            return false;
        }
        StoredProgress other = (StoredProgress) o;
        return Objects.equals(phaseProgresses, other.phaseProgresses);
    }

    @Override
    public int hashCode() {
        return Objects.hash(phaseProgresses);
    }
}

View File

@ -31,4 +31,10 @@ public interface AnalyticsProcess<ProcessResult> extends NativeProcess {
* a SIGPIPE
*/
void consumeAndCloseOutputStream();
/**
*
* @return the process config
*/
AnalyticsProcessConfig getConfig();
}

View File

@ -43,6 +43,10 @@ public class AnalyticsProcessConfig implements ToXContentObject {
this.analysis = Objects.requireNonNull(analysis);
}
public long rows() {
return rows;
}
public int cols() {
return cols;
}

View File

@ -11,7 +11,6 @@ import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@ -31,7 +30,6 @@ import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
public class AnalyticsProcessManager {
@ -90,14 +88,15 @@ public class AnalyticsProcessManager {
Consumer<Exception> finishHandler) {
try {
ProcessContext processContext = processContextByAllocation.get(task.getAllocationId());
writeHeaderRecord(dataExtractor, process);
writeDataRows(dataExtractor, process);
writeDataRows(dataExtractor, process, task.getProgressTracker());
process.writeEndOfDataMessage();
process.flushStream();
LOGGER.info("[{}] Waiting for result processor to complete", config.getId());
resultProcessor.awaitForCompletion();
processContextByAllocation.get(task.getAllocationId()).setFailureReason(resultProcessor.getFailure());
processContext.setFailureReason(resultProcessor.getFailure());
refreshDest(config);
LOGGER.info("[{}] Result processor has completed", config.getId());
@ -122,12 +121,16 @@ public class AnalyticsProcessManager {
}
}
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process) throws IOException {
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
DataFrameAnalyticsTask.ProgressTracker progressTracker) throws IOException {
// The extra fields are for the doc hash and the control field (should be an empty string)
String[] record = new String[dataExtractor.getFieldNames().size() + 2];
// The value of the control field should be an empty string for data frame rows
record[record.length - 1] = "";
long totalRows = process.getConfig().rows();
long rowsProcessed = 0;
while (dataExtractor.hasNext()) {
Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
if (rows.isPresent()) {
@ -139,6 +142,8 @@ public class AnalyticsProcessManager {
process.writeRecord(record);
}
}
rowsProcessed += rows.get().size();
progressTracker.loadingDataPercent.set(rowsProcessed >= totalRows ? 100 : (int) (rowsProcessed * 100.0 / totalRows));
}
}
}
@ -179,12 +184,6 @@ public class AnalyticsProcessManager {
};
}
@Nullable
public Integer getProgressPercent(long allocationId) {
ProcessContext processContext = processContextByAllocation.get(allocationId);
return processContext == null ? null : processContext.progressPercent.get();
}
private void refreshDest(DataFrameAnalyticsConfig config) {
ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client,
() -> client.execute(RefreshAction.INSTANCE, new RefreshRequest(config.getDest().getIndex())).actionGet());
@ -222,7 +221,6 @@ public class AnalyticsProcessManager {
private volatile AnalyticsProcess<AnalyticsResult> process;
private volatile DataFrameDataExtractor dataExtractor;
private volatile AnalyticsResultProcessor resultProcessor;
private final AtomicInteger progressPercent = new AtomicInteger(0);
private volatile boolean processKilled;
private volatile String failureReason;
@ -238,10 +236,6 @@ public class AnalyticsProcessManager {
return processKilled;
}
void setProgressPercent(int progressPercent) {
this.progressPercent.set(progressPercent);
}
private synchronized void setFailureReason(String failureReason) {
// Only set the new reason if there isn't one already as we want to keep the first reason
if (failureReason != null) {
@ -282,7 +276,7 @@ public class AnalyticsProcessManager {
process = createProcess(task, createProcessConfig(config, dataExtractor));
DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client,
dataExtractorFactory.newExtractor(true));
resultProcessor = new AnalyticsResultProcessor(id, dataFrameRowsJoiner, this::isProcessKilled, this::setProgressPercent);
resultProcessor = new AnalyticsResultProcessor(id, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker());
return true;
}

View File

@ -9,13 +9,13 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import java.util.Iterator;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.function.Consumer;
import java.util.function.Supplier;
public class AnalyticsResultProcessor {
@ -25,16 +25,16 @@ public class AnalyticsResultProcessor {
private final String dataFrameAnalyticsId;
private final DataFrameRowsJoiner dataFrameRowsJoiner;
private final Supplier<Boolean> isProcessKilled;
private final Consumer<Integer> progressConsumer;
private final ProgressTracker progressTracker;
private final CountDownLatch completionLatch = new CountDownLatch(1);
private volatile String failure;
public AnalyticsResultProcessor(String dataFrameAnalyticsId, DataFrameRowsJoiner dataFrameRowsJoiner, Supplier<Boolean> isProcessKilled,
Consumer<Integer> progressConsumer) {
ProgressTracker progressTracker) {
this.dataFrameAnalyticsId = Objects.requireNonNull(dataFrameAnalyticsId);
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
this.isProcessKilled = Objects.requireNonNull(isProcessKilled);
this.progressConsumer = Objects.requireNonNull(progressConsumer);
this.progressTracker = Objects.requireNonNull(progressTracker);
}
@Nullable
@ -52,12 +52,25 @@ public class AnalyticsResultProcessor {
}
public void process(AnalyticsProcess<AnalyticsResult> process) {
long totalRows = process.getConfig().rows();
LOGGER.info("Total rows = {}", totalRows);
long processedRows = 0;
// TODO When java 9 features can be used, we will not need the local variable here
try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) {
Iterator<AnalyticsResult> iterator = process.readAnalyticsResults();
while (iterator.hasNext()) {
AnalyticsResult result = iterator.next();
processResult(result, resultsJoiner);
if (result.getRowResults() != null) {
processedRows++;
progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows));
}
}
if (isProcessKilled.get() == false) {
// This means we completed successfully so we need to set the progress to 100.
// This is because due to skipped rows, it is possible the processed rows will not reach the total rows.
progressTracker.writingResultsPercent.set(100);
}
} catch (Exception e) {
if (isProcessKilled.get()) {
@ -79,7 +92,7 @@ public class AnalyticsResultProcessor {
}
Integer progressPercent = result.getProgressPercent();
if (progressPercent != null) {
progressConsumer.accept(progressPercent);
progressTracker.analyzingPercent.set(progressPercent);
}
}
}

View File

@ -6,21 +6,54 @@
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.process.ProcessResultsParser;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Path;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess<AnalyticsResult> {
private static final String NAME = "analytics";
protected NativeAnalyticsProcess(String jobId, InputStream logStream, OutputStream processInStream,
InputStream processOutStream, OutputStream processRestoreStream, int numberOfFields,
List<Path> filesToDelete, Consumer<String> onProcessCrash) {
private final ProcessResultsParser<AnalyticsResult> resultsParser = new ProcessResultsParser<>(AnalyticsResult.PARSER);
private final AnalyticsProcessConfig config;
protected NativeAnalyticsProcess(String jobId, InputStream logStream, OutputStream processInStream, InputStream processOutStream,
OutputStream processRestoreStream, int numberOfFields, List<Path> filesToDelete,
Consumer<String> onProcessCrash, AnalyticsProcessConfig config) {
super(NAME, AnalyticsResult.PARSER, jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields,
filesToDelete, onProcessCrash);
this.config = Objects.requireNonNull(config);
}
@Override
public String getName() {
return NAME;
}
@Override
public void persistState() {
// Nothing to persist
}
@Override
public void writeEndOfDataMessage() throws IOException {
new AnalyticsControlMessageWriter(recordWriter(), numberOfFields()).writeEndOfData();
}
@Override
public Iterator<AnalyticsResult> readAnalyticsResults() {
return resultsParser.parseResults(processOutStream());
}
@Override
public AnalyticsProcessConfig getConfig() {
return config;
}
}

View File

@ -64,7 +64,7 @@ public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory<An
NativeAnalyticsProcess analyticsProcess = new NativeAnalyticsProcess(jobId, processPipes.getLogStream().get(),
processPipes.getProcessInStream().get(), processPipes.getProcessOutStream().get(), null, numberOfFields,
filesToDelete, onProcessCrash);
filesToDelete, onProcessCrash, analyticsProcessConfig);
try {
analyticsProcess.start(executorService);

View File

@ -24,4 +24,9 @@ public class NativeMemoryUsageEstimationProcess extends AbstractNativeAnalyticsP
super(NAME, MemoryUsageEstimationResult.PARSER, jobId, logStream, processInStream, processOutStream, processRestoreStream,
numberOfFields, filesToDelete, onProcessCrash);
}
@Override
public AnalyticsProcessConfig getConfig() {
throw new UnsupportedOperationException();
}
}

View File

@ -0,0 +1,37 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
 * Round-trip XContent serialization tests for {@link StoredProgress}.
 */
public class StoredProgressTests extends AbstractXContentTestCase<StoredProgress> {

    @Override
    protected StoredProgress doParseInstance(XContentParser parser) throws IOException {
        return StoredProgress.PARSER.apply(parser, null);
    }

    @Override
    protected boolean supportsUnknownFields() {
        // The parser is lenient, so unknown fields must be tolerated
        return true;
    }

    @Override
    protected StoredProgress createTestInstance() {
        // Random number of phases, each with a random name and percentage
        int phases = randomIntBetween(3, 7);
        List<PhaseProgress> phaseProgresses = new ArrayList<>(phases);
        for (int phaseIndex = 0; phaseIndex < phases; phaseIndex++) {
            phaseProgresses.add(new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)));
        }
        return new StoredProgress(phaseProgresses);
    }
}

View File

@ -5,7 +5,10 @@
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.junit.Before;
@ -28,8 +31,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
private AnalyticsProcess<AnalyticsResult> process;
private DataFrameRowsJoiner dataFrameRowsJoiner;
private int progressPercent;
private ProgressTracker progressTracker = new ProgressTracker();
@Before
@SuppressWarnings("unchecked")
@ -39,6 +41,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
}
public void testProcess_GivenNoResults() {
givenDataFrameRows(0);
givenProcessResults(Collections.emptyList());
AnalyticsResultProcessor resultProcessor = createResultProcessor();
@ -50,6 +53,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
}
public void testProcess_GivenEmptyResults() {
givenDataFrameRows(2);
givenProcessResults(Arrays.asList(new AnalyticsResult(null, 50), new AnalyticsResult(null, 100)));
AnalyticsResultProcessor resultProcessor = createResultProcessor();
@ -58,10 +62,11 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
verify(dataFrameRowsJoiner).close();
Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner);
assertThat(progressPercent, equalTo(100));
assertThat(progressTracker.writingResultsPercent.get(), equalTo(100));
}
public void testProcess_GivenRowResults() {
givenDataFrameRows(2);
RowResults rowResults1 = mock(RowResults.class);
RowResults rowResults2 = mock(RowResults.class);
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50), new AnalyticsResult(rowResults2, 100)));
@ -74,15 +79,20 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults1);
inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults2);
assertThat(progressPercent, equalTo(100));
assertThat(progressTracker.writingResultsPercent.get(), equalTo(100));
}
private void givenProcessResults(List<AnalyticsResult> results) {
when(process.readAnalyticsResults()).thenReturn(results.iterator());
}
private void givenDataFrameRows(int rows) {
AnalyticsProcessConfig config = new AnalyticsProcessConfig(
rows, 1, ByteSizeValue.ZERO, 1, "ml", Collections.emptySet(), mock(DataFrameAnalysis.class));
when(process.getConfig()).thenReturn(config);
}
private AnalyticsResultProcessor createResultProcessor() {
return new AnalyticsResultProcessor(JOB_ID, dataFrameRowsJoiner, () -> false,
progressPercent -> this.progressPercent = progressPercent);
return new AnalyticsResultProcessor(JOB_ID, dataFrameRowsJoiner, () -> false, progressTracker);
}
}