[7.x] Implement ml/data_frame/analytics/_estimate_memory_usage API endpoint (#45188) (#45510)

Przemysław Witek 2019-08-14 08:26:03 +02:00 committed by GitHub
parent 84bf98e9cd
commit df574e5168
42 changed files with 1882 additions and 190 deletions

View File

@ -0,0 +1,83 @@
[role="xpack"]
[testenv="platinum"]
[[estimate-memory-usage-dfanalytics]]
=== Estimate memory usage API
[subs="attributes"]
++++
<titleabbrev>Estimate memory usage for {dfanalytics-jobs}</titleabbrev>
++++
Estimates memory usage for the given {dataframe-analytics-config}.
experimental[]
[[ml-estimate-memory-usage-dfanalytics-request]]
==== {api-request-title}
`POST _ml/data_frame/analytics/_estimate_memory_usage`
[[ml-estimate-memory-usage-dfanalytics-prereq]]
==== {api-prereq-title}
* You must have the `monitor_ml` privilege to use this API. For more
information, see {stack-ov}/security-privileges.html[Security privileges] and
{stack-ov}/built-in-roles.html[Built-in roles].
[[ml-estimate-memory-usage-dfanalytics-desc]]
==== {api-description-title}
This API estimates memory usage for the given {dataframe-analytics-config} before the {dfanalytics-job} is even created.
The estimate serves as advice on how to set `model_memory_limit` when creating the {dfanalytics-job}.
[[ml-estimate-memory-usage-dfanalytics-request-body]]
==== {api-request-body-title}
`data_frame_analytics_config`::
(Required, object) Intended configuration of {dfanalytics-job}. For more information, see
<<ml-dfanalytics-resources>>.
Note that `id` and `dest` don't need to be provided in the context of this API.
[[ml-estimate-memory-usage-dfanalytics-results]]
==== {api-response-body-title}
`expected_memory_usage_with_one_partition`::
(string) Estimated memory usage under the assumption that the whole {dfanalytics} should happen in memory
(i.e. without overflowing to disk).
`expected_memory_usage_with_max_partitions`::
(string) Estimated memory usage under the assumption that overflowing to disk is allowed during {dfanalytics}.
`expected_memory_usage_with_max_partitions` is usually smaller than `expected_memory_usage_with_one_partition`
as using disk allows limiting the amount of main memory needed to perform {dfanalytics}.
[[ml-estimate-memory-usage-dfanalytics-example]]
==== {api-examples-title}
[source,js]
--------------------------------------------------
POST _ml/data_frame/analytics/_estimate_memory_usage
{
"data_frame_analytics_config": {
"source": {
"index": "logdata"
},
"analysis": {
"outlier_detection": {}
}
}
}
--------------------------------------------------
// CONSOLE
// TEST[skip:TBD]
The API returns the following results:
[source,js]
----
{
"expected_memory_usage_with_one_partition": "128MB",
"expected_memory_usage_with_max_partitions": "32MB"
}
----
// TESTRESPONSE
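For server-side Java callers, a minimal sketch of invoking the new action (not part of this commit's docs; it assumes the `EstimateMemoryUsageAction` classes added elsewhere in this change, plus a `client` and `logger` already in scope):

[source,java]
----
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
    .setSource(new DataFrameAnalyticsSource(new String[] { "logdata" }, null))
    .setAnalysis(new OutlierDetection())
    .buildForMemoryEstimation(); // "id" and "dest" may be omitted for estimation

client.execute(
    EstimateMemoryUsageAction.INSTANCE,
    new EstimateMemoryUsageAction.Request(config),
    ActionListener.wrap(
        response -> logger.info("estimate with one partition [{}], with max partitions [{}]",
            response.getExpectedMemoryUsageWithOnePartition(),
            response.getExpectedMemoryUsageWithMaxPartitions()),
        e -> logger.error("memory usage estimation failed", e)));
----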

View File

@ -12,6 +12,7 @@ You can use the following APIs to perform {ml} {dfanalytics} activities.
* <<start-dfanalytics,Start {dfanalytics-jobs}>>
* <<stop-dfanalytics,Stop {dfanalytics-jobs}>>
* <<evaluate-dfanalytics,Evaluate {dfanalytics}>>
* <<estimate-memory-usage-dfanalytics,Estimate memory usage for {dfanalytics}>>
See also <<ml-apis>>.
@ -21,10 +22,11 @@ include::put-dfanalytics.asciidoc[]
include::delete-dfanalytics.asciidoc[]
//EVALUATE
include::evaluate-dfanalytics.asciidoc[]
//ESTIMATE_MEMORY_USAGE
include::estimate-memory-usage-dfanalytics.asciidoc[]
//GET
include::get-dfanalytics.asciidoc[]
include::get-dfanalytics-stats.asciidoc[]
//SET/START/STOP
include::start-dfanalytics.asciidoc[]
include::stop-dfanalytics.asciidoc[]

View File

@ -96,6 +96,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteFilterAction;
import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction;
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
import org.elasticsearch.xpack.core.ml.action.FindFileStructureAction;
@ -347,6 +348,7 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
StartDataFrameAnalyticsAction.INSTANCE,
StopDataFrameAnalyticsAction.INSTANCE,
EvaluateDataFrameAction.INSTANCE,
EstimateMemoryUsageAction.INSTANCE,
// security
ClearRealmCacheAction.INSTANCE,
ClearRolesCacheAction.INSTANCE,

View File

@ -0,0 +1,204 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class EstimateMemoryUsageAction extends ActionType<EstimateMemoryUsageAction.Response> {
public static final EstimateMemoryUsageAction INSTANCE = new EstimateMemoryUsageAction();
public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/estimate_memory_usage";
private EstimateMemoryUsageAction() {
super(NAME, EstimateMemoryUsageAction.Response::new);
}
public static class Request extends ActionRequest implements ToXContentObject {
private static final ParseField DATA_FRAME_ANALYTICS_CONFIG = new ParseField("data_frame_analytics_config");
private static final ConstructingObjectParser<EstimateMemoryUsageAction.Request, Void> PARSER =
new ConstructingObjectParser<>(
NAME,
args -> {
DataFrameAnalyticsConfig.Builder configBuilder = (DataFrameAnalyticsConfig.Builder) args[0];
DataFrameAnalyticsConfig config = configBuilder.buildForMemoryEstimation();
return new EstimateMemoryUsageAction.Request(config);
});
static {
PARSER.declareObject(constructorArg(), DataFrameAnalyticsConfig.STRICT_PARSER, DATA_FRAME_ANALYTICS_CONFIG);
}
public static EstimateMemoryUsageAction.Request parseRequest(XContentParser parser) {
return PARSER.apply(parser, null);
}
private final DataFrameAnalyticsConfig config;
public Request(DataFrameAnalyticsConfig config) {
this.config = ExceptionsHelper.requireNonNull(config, DATA_FRAME_ANALYTICS_CONFIG);
}
public Request(StreamInput in) throws IOException {
super(in);
this.config = new DataFrameAnalyticsConfig(in);
}
@Override
public ActionRequestValidationException validate() {
return null;
}
public DataFrameAnalyticsConfig getConfig() {
return config;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
config.writeTo(out);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(DATA_FRAME_ANALYTICS_CONFIG.getPreferredName(), config);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || getClass() != other.getClass()) {
return false;
}
Request that = (Request) other;
return Objects.equals(config, that.config);
}
@Override
public int hashCode() {
return Objects.hash(config);
}
}
public static class Response extends ActionResponse implements ToXContentObject {
public static final ParseField TYPE = new ParseField("memory_usage_estimation_result");
public static final ParseField EXPECTED_MEMORY_USAGE_WITH_ONE_PARTITION =
new ParseField("expected_memory_usage_with_one_partition");
public static final ParseField EXPECTED_MEMORY_USAGE_WITH_MAX_PARTITIONS =
new ParseField("expected_memory_usage_with_max_partitions");
static final ConstructingObjectParser<Response, Void> PARSER =
new ConstructingObjectParser<>(
TYPE.getPreferredName(),
args -> new Response((ByteSizeValue) args[0], (ByteSizeValue) args[1]));
static {
PARSER.declareField(
optionalConstructorArg(),
(p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), EXPECTED_MEMORY_USAGE_WITH_ONE_PARTITION.getPreferredName()),
EXPECTED_MEMORY_USAGE_WITH_ONE_PARTITION,
ObjectParser.ValueType.VALUE);
PARSER.declareField(
optionalConstructorArg(),
(p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), EXPECTED_MEMORY_USAGE_WITH_MAX_PARTITIONS.getPreferredName()),
EXPECTED_MEMORY_USAGE_WITH_MAX_PARTITIONS,
ObjectParser.ValueType.VALUE);
}
private final ByteSizeValue expectedMemoryUsageWithOnePartition;
private final ByteSizeValue expectedMemoryUsageWithMaxPartitions;
public Response(@Nullable ByteSizeValue expectedMemoryUsageWithOnePartition,
@Nullable ByteSizeValue expectedMemoryUsageWithMaxPartitions) {
this.expectedMemoryUsageWithOnePartition = expectedMemoryUsageWithOnePartition;
this.expectedMemoryUsageWithMaxPartitions = expectedMemoryUsageWithMaxPartitions;
}
public Response(StreamInput in) throws IOException {
super(in);
this.expectedMemoryUsageWithOnePartition = in.readOptionalWriteable(ByteSizeValue::new);
this.expectedMemoryUsageWithMaxPartitions = in.readOptionalWriteable(ByteSizeValue::new);
}
public ByteSizeValue getExpectedMemoryUsageWithOnePartition() {
return expectedMemoryUsageWithOnePartition;
}
public ByteSizeValue getExpectedMemoryUsageWithMaxPartitions() {
return expectedMemoryUsageWithMaxPartitions;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalWriteable(expectedMemoryUsageWithOnePartition);
out.writeOptionalWriteable(expectedMemoryUsageWithMaxPartitions);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (expectedMemoryUsageWithOnePartition != null) {
builder.field(
EXPECTED_MEMORY_USAGE_WITH_ONE_PARTITION.getPreferredName(), expectedMemoryUsageWithOnePartition.getStringRep());
}
if (expectedMemoryUsageWithMaxPartitions != null) {
builder.field(
EXPECTED_MEMORY_USAGE_WITH_MAX_PARTITIONS.getPreferredName(), expectedMemoryUsageWithMaxPartitions.getStringRep());
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || getClass() != other.getClass()) {
return false;
}
Response that = (Response) other;
return Objects.equals(expectedMemoryUsageWithOnePartition, that.expectedMemoryUsageWithOnePartition)
&& Objects.equals(expectedMemoryUsageWithMaxPartitions, that.expectedMemoryUsageWithMaxPartitions);
}
@Override
public int hashCode() {
return Objects.hash(expectedMemoryUsageWithOnePartition, expectedMemoryUsageWithMaxPartitions);
}
}
}
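A quick illustration of the null-skipping serialization in `Response#toXContent` above (a sketch; the values are hypothetical):

// Sketch: a Response with the max-partitions estimate absent.
EstimateMemoryUsageAction.Response response =
    new EstimateMemoryUsageAction.Response(new ByteSizeValue(128, ByteSizeUnit.MB), null);
XContentBuilder builder = JsonXContent.contentBuilder();
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
// Strings.toString(builder) -> {"expected_memory_usage_with_one_partition":"128mb"}
// The null field is omitted entirely rather than serialized as null.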

View File

@ -57,7 +57,7 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable {
public static final ObjectParser<Builder, Void> STRICT_PARSER = createParser(false);
public static final ObjectParser<Builder, Void> LENIENT_PARSER = createParser(true);
public static ObjectParser<Builder, Void> createParser(boolean ignoreUnknownFields) {
private static ObjectParser<Builder, Void> createParser(boolean ignoreUnknownFields) {
ObjectParser<Builder, Void> parser = new ObjectParser<>(TYPE, ignoreUnknownFields, Builder::new);
parser.declareString((c, s) -> {}, CONFIG_TYPE);
@ -281,14 +281,6 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable {
public Builder() {}
public Builder(String id) {
setId(id);
}
public Builder(ByteSizeValue maxModelMemoryLimit) {
this.maxModelMemoryLimit = maxModelMemoryLimit;
}
public Builder(DataFrameAnalyticsConfig config) {
this(config, null);
}
@ -343,30 +335,10 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable {
}
public Builder setModelMemoryLimit(ByteSizeValue modelMemoryLimit) {
if (modelMemoryLimit != null && modelMemoryLimit.compareTo(MIN_MODEL_MEMORY_LIMIT) < 0) {
throw new IllegalArgumentException("[" + MODEL_MEMORY_LIMIT.getPreferredName()
+ "] must be at least [" + MIN_MODEL_MEMORY_LIMIT.getStringRep() + "]");
}
this.modelMemoryLimit = modelMemoryLimit;
return this;
}
private void applyMaxModelMemoryLimit() {
boolean maxModelMemoryIsSet = maxModelMemoryLimit != null && maxModelMemoryLimit.getMb() > 0;
if (modelMemoryLimit == null) {
// Default is silently capped if higher than limit
if (maxModelMemoryIsSet && DEFAULT_MODEL_MEMORY_LIMIT.compareTo(maxModelMemoryLimit) > 0) {
modelMemoryLimit = maxModelMemoryLimit;
}
} else if (maxModelMemoryIsSet && modelMemoryLimit.compareTo(maxModelMemoryLimit) > 0) {
// Explicit setting higher than limit is an error
throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.JOB_CONFIG_MODEL_MEMORY_LIMIT_GREATER_THAN_MAX,
modelMemoryLimit, maxModelMemoryLimit));
}
}
public Builder setCreateTime(Instant createTime) {
this.createTime = createTime;
return this;
@ -377,9 +349,53 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable {
return this;
}
/**
* Builds a {@link DataFrameAnalyticsConfig} object.
*/
public DataFrameAnalyticsConfig build() {
applyMaxModelMemoryLimit();
return new DataFrameAnalyticsConfig(id, source, dest, analysis, headers, modelMemoryLimit, analyzedFields, createTime, version);
}
/**
* Builds a {@link DataFrameAnalyticsConfig} object for the purpose of performing memory estimation.
* Some fields (i.e. "id" and "dest") may not be present, so we substitute dummy values here to satisfy
* the {@link DataFrameAnalyticsConfig} constructor validations.
*/
public DataFrameAnalyticsConfig buildForMemoryEstimation() {
return new DataFrameAnalyticsConfig(
id != null ? id : "dummy",
source,
dest != null ? dest : new DataFrameAnalyticsDest("dummy", null),
analysis,
headers,
modelMemoryLimit,
analyzedFields,
createTime,
version);
}
private void applyMaxModelMemoryLimit() {
boolean maxModelMemoryIsSet = maxModelMemoryLimit != null && maxModelMemoryLimit.getMb() > 0;
if (modelMemoryLimit != null) {
if (modelMemoryLimit.compareTo(MIN_MODEL_MEMORY_LIMIT) < 0) {
// Explicit setting lower than minimum is an error
throw ExceptionsHelper.badRequestException(
Messages.getMessage(Messages.JOB_CONFIG_MODEL_MEMORY_LIMIT_TOO_LOW, modelMemoryLimit));
}
if (maxModelMemoryIsSet && modelMemoryLimit.compareTo(maxModelMemoryLimit) > 0) {
// Explicit setting higher than limit is an error
throw ExceptionsHelper.badRequestException(
Messages.getMessage(
Messages.JOB_CONFIG_MODEL_MEMORY_LIMIT_GREATER_THAN_MAX, modelMemoryLimit, maxModelMemoryLimit));
}
} else {
// Default is silently capped if higher than limit
if (maxModelMemoryIsSet && DEFAULT_MODEL_MEMORY_LIMIT.compareTo(maxModelMemoryLimit) > 0) {
modelMemoryLimit = maxModelMemoryLimit;
}
}
}
}
}
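To summarize the relocated validation, a sketch of the outcomes `build()` now produces (illustrative only; it assumes the cluster-wide `xpack.ml.max_model_memory_limit` setting resolves `maxModelMemoryLimit` to 512mb and that the builder's required fields are already set):

builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.MB));
builder.build(); // throws: model_memory_limit must be at least 1 MiB. Value = ...
builder.setModelMemoryLimit(new ByteSizeValue(1, ByteSizeUnit.GB));
builder.build(); // throws: model_memory_limit [1gb] must be less than the value of the xpack.ml.max_model_memory_limit setting
builder.setModelMemoryLimit(null);
builder.build(); // succeeds; a default higher than 512mb is silently capped to 512mb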

View File

@ -122,7 +122,7 @@ public final class Messages {
"Invalid detector rule: scope field ''{0}'' is invalid; select from {1}";
public static final String JOB_CONFIG_FIELDNAME_INCOMPATIBLE_FUNCTION = "field_name cannot be used with function ''{0}''";
public static final String JOB_CONFIG_FIELD_VALUE_TOO_LOW = "{0} cannot be less than {1,number}. Value = {2,number}";
public static final String JOB_CONFIG_MODEL_MEMORY_LIMIT_TOO_LOW = "model_memory_limit must be at least 1 MiB. Value = {0,number}";
public static final String JOB_CONFIG_MODEL_MEMORY_LIMIT_TOO_LOW = "model_memory_limit must be at least 1 MiB. Value = {0}";
public static final String JOB_CONFIG_MODEL_MEMORY_LIMIT_GREATER_THAN_MAX =
"model_memory_limit [{0}] must be less than the value of the " +
MachineLearningField.MAX_MODEL_MEMORY_LIMIT.getKey() +

View File

@ -0,0 +1,55 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction.Request;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class EstimateMemoryUsageActionRequestTests extends AbstractSerializingTestCase<Request> {
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables());
return new NamedWriteableRegistry(namedWriteables);
}
@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
}
@Override
protected Request createTestInstance() {
return new Request(DataFrameAnalyticsConfigTests.createRandom("dummy"));
}
@Override
protected Writeable.Reader<Request> instanceReader() {
return Request::new;
}
@Override
protected Request doParseInstance(XContentParser parser) {
return Request.parseRequest(parser);
}
}

View File

@ -0,0 +1,47 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction.Response;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
public class EstimateMemoryUsageActionResponseTests extends AbstractSerializingTestCase<Response> {
@Override
protected Response createTestInstance() {
return new Response(
randomBoolean() ? new ByteSizeValue(randomNonNegativeLong()) : null,
randomBoolean() ? new ByteSizeValue(randomNonNegativeLong()) : null);
}
@Override
protected Writeable.Reader<Response> instanceReader() {
return Response::new;
}
@Override
protected Response doParseInstance(XContentParser parser) {
return Response.PARSER.apply(parser, null);
}
public void testConstructor_NullValues() {
Response response = new Response(null, null);
assertThat(response.getExpectedMemoryUsageWithOnePartition(), nullValue());
assertThat(response.getExpectedMemoryUsageWithMaxPartitions(), nullValue());
}
public void testConstructor() {
Response response = new Response(new ByteSizeValue(2048), new ByteSizeValue(1024));
assertThat(response.getExpectedMemoryUsageWithOnePartition(), equalTo(new ByteSizeValue(2048)));
assertThat(response.getExpectedMemoryUsageWithMaxPartitions(), equalTo(new ByteSizeValue(1024)));
}
}

View File

@ -43,7 +43,6 @@ import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize;
@ -227,18 +226,18 @@ public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase<D
DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder();
// All these are different ways of specifying a limit that is lower than the minimum
assertTooSmall(expectThrows(IllegalArgumentException.class,
() -> builder.setModelMemoryLimit(new ByteSizeValue(1048575, ByteSizeUnit.BYTES))));
assertTooSmall(expectThrows(IllegalArgumentException.class,
() -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.BYTES))));
assertTooSmall(expectThrows(IllegalArgumentException.class,
() -> builder.setModelMemoryLimit(new ByteSizeValue(-1, ByteSizeUnit.BYTES))));
assertTooSmall(expectThrows(IllegalArgumentException.class,
() -> builder.setModelMemoryLimit(new ByteSizeValue(1023, ByteSizeUnit.KB))));
assertTooSmall(expectThrows(IllegalArgumentException.class,
() -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.KB))));
assertTooSmall(expectThrows(IllegalArgumentException.class,
() -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.MB))));
assertTooSmall(expectThrows(ElasticsearchStatusException.class,
() -> builder.setModelMemoryLimit(new ByteSizeValue(1048575, ByteSizeUnit.BYTES)).build()));
assertTooSmall(expectThrows(ElasticsearchStatusException.class,
() -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.BYTES)).build()));
assertTooSmall(expectThrows(ElasticsearchStatusException.class,
() -> builder.setModelMemoryLimit(new ByteSizeValue(-1, ByteSizeUnit.BYTES)).build()));
assertTooSmall(expectThrows(ElasticsearchStatusException.class,
() -> builder.setModelMemoryLimit(new ByteSizeValue(1023, ByteSizeUnit.KB)).build()));
assertTooSmall(expectThrows(ElasticsearchStatusException.class,
() -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.KB)).build()));
assertTooSmall(expectThrows(ElasticsearchStatusException.class,
() -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.MB)).build()));
}
public void testNoMemoryCapping() {
@ -276,6 +275,36 @@ public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase<D
assertThat(e.getMessage(), containsString("must be less than the value of the xpack.ml.max_model_memory_limit setting"));
}
public void testBuildForMemoryEstimation() {
DataFrameAnalyticsConfig.Builder builder = createRandomBuilder("foo");
DataFrameAnalyticsConfig config = builder.buildForMemoryEstimation();
assertThat(config, equalTo(builder.build()));
}
public void testBuildForMemoryEstimation_MissingId() {
DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder()
.setAnalysis(OutlierDetectionTests.createRandom())
.setSource(DataFrameAnalyticsSourceTests.createRandom())
.setDest(DataFrameAnalyticsDestTests.createRandom());
DataFrameAnalyticsConfig config = builder.buildForMemoryEstimation();
assertThat(config.getId(), equalTo("dummy"));
}
public void testBuildForMemoryEstimation_MissingDest() {
DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder()
.setId("foo")
.setAnalysis(OutlierDetectionTests.createRandom())
.setSource(DataFrameAnalyticsSourceTests.createRandom());
DataFrameAnalyticsConfig config = builder.buildForMemoryEstimation();
assertThat(config.getDest().getIndex(), equalTo("dummy"));
}
public void testPreventCreateTimeInjection() throws IOException {
String json = "{"
+ " \"create_time\" : 123456789 },"
@ -306,7 +335,7 @@ public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase<D
}
}
public void assertTooSmall(IllegalArgumentException e) {
assertThat(e.getMessage(), is("[model_memory_limit] must be at least [1mb]"));
private static void assertTooSmall(ElasticsearchStatusException e) {
assertThat(e.getMessage(), startsWith("model_memory_limit must be at least 1 MiB."));
}
}

View File

@ -105,8 +105,9 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
}
protected static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String[] sourceIndex, String destIndex,
@Nullable String resultsField) {
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id);
@Nullable String resultsField) {
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();
configBuilder.setId(id);
configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null));
configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField));
configBuilder.setAnalysis(new OutlierDetection());
@ -122,7 +123,8 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex,
@Nullable String resultsField, String dependentVariable) {
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id);
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();
configBuilder.setId(id);
configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null));
configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField));
configBuilder.setAnalysis(new Regression(dependentVariable));

View File

@ -72,6 +72,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteFilterAction;
import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction;
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
import org.elasticsearch.xpack.core.ml.action.FindFileStructureAction;
@ -136,6 +137,7 @@ import org.elasticsearch.xpack.ml.action.TransportDeleteFilterAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteForecastAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteJobAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction;
import org.elasticsearch.xpack.ml.action.TransportEstimateMemoryUsageAction;
import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction;
import org.elasticsearch.xpack.ml.action.TransportFinalizeJobExecutionAction;
import org.elasticsearch.xpack.ml.action.TransportFindFileStructureAction;
@ -190,6 +192,10 @@ import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager;
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessFactory;
import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.MemoryUsageEstimationProcessManager;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactory;
import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
@ -235,6 +241,7 @@ import org.elasticsearch.xpack.ml.rest.datafeeds.RestStartDatafeedAction;
import org.elasticsearch.xpack.ml.rest.datafeeds.RestStopDatafeedAction;
import org.elasticsearch.xpack.ml.rest.datafeeds.RestUpdateDatafeedAction;
import org.elasticsearch.xpack.ml.rest.dataframe.RestDeleteDataFrameAnalyticsAction;
import org.elasticsearch.xpack.ml.rest.dataframe.RestEstimateMemoryUsageAction;
import org.elasticsearch.xpack.ml.rest.dataframe.RestEvaluateDataFrameAction;
import org.elasticsearch.xpack.ml.rest.dataframe.RestGetDataFrameAnalyticsAction;
import org.elasticsearch.xpack.ml.rest.dataframe.RestGetDataFrameAnalyticsStatsAction;
@ -487,7 +494,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
AutodetectProcessFactory autodetectProcessFactory;
NormalizerProcessFactory normalizerProcessFactory;
AnalyticsProcessFactory analyticsProcessFactory;
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory;
AnalyticsProcessFactory<MemoryUsageEstimationResult> memoryEstimationProcessFactory;
if (MachineLearningField.AUTODETECT_PROCESS.get(settings) && MachineLearningFeatureSet.isRunningOnMlPlatform(true)) {
try {
NativeController nativeController = NativeControllerHolder.getNativeController(clusterService.getNodeName(), environment);
@ -503,6 +511,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
clusterService);
normalizerProcessFactory = new NativeNormalizerProcessFactory(environment, nativeController, clusterService);
analyticsProcessFactory = new NativeAnalyticsProcessFactory(environment, nativeController, clusterService);
memoryEstimationProcessFactory =
new NativeMemoryUsageEstimationProcessFactory(environment, nativeController, clusterService);
} catch (IOException e) {
// The low level cause of failure from the named pipe helper's perspective is almost never the real root cause, so
// only log this at the lowest level of detail. It's almost always "file not found" on a named pipe we expect to be
@ -519,6 +529,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
// factor of 1.0 makes renormalization a no-op
normalizerProcessFactory = (jobId, quantilesState, bucketSpan, executorService) -> new MultiplyingNormalizerProcess(1.0);
analyticsProcessFactory = (jobId, analyticsProcessConfig, executorService, onProcessCrash) -> null;
memoryEstimationProcessFactory = (jobId, analyticsProcessConfig, executorService, onProcessCrash) -> null;
}
NormalizerFactory normalizerFactory = new NormalizerFactory(normalizerProcessFactory,
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME));
@ -542,6 +553,9 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
// Data frame analytics components
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory);
MemoryUsageEstimationProcessManager memoryEstimationProcessManager =
new MemoryUsageEstimationProcessManager(
threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), memoryEstimationProcessFactory);
DataFrameAnalyticsConfigProvider dataFrameAnalyticsConfigProvider = new DataFrameAnalyticsConfigProvider(client);
assert client instanceof NodeClient;
DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager((NodeClient) client,
@ -579,6 +593,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
new MlAssignmentNotifier(settings, auditor, threadPool, client, clusterService),
memoryTracker,
analyticsProcessManager,
memoryEstimationProcessManager,
dataFrameAnalyticsConfigProvider,
nativeStorageProvider
);
@ -676,7 +691,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
new RestDeleteDataFrameAnalyticsAction(restController),
new RestStartDataFrameAnalyticsAction(restController),
new RestStopDataFrameAnalyticsAction(restController),
new RestEvaluateDataFrameAction(restController)
new RestEvaluateDataFrameAction(restController),
new RestEstimateMemoryUsageAction(restController)
);
}
@ -742,7 +758,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
new ActionHandler<>(DeleteDataFrameAnalyticsAction.INSTANCE, TransportDeleteDataFrameAnalyticsAction.class),
new ActionHandler<>(StartDataFrameAnalyticsAction.INSTANCE, TransportStartDataFrameAnalyticsAction.class),
new ActionHandler<>(StopDataFrameAnalyticsAction.INSTANCE, TransportStopDataFrameAnalyticsAction.class),
new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class)
new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class),
new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class)
);
}

View File

@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.action;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.MemoryUsageEstimationProcessManager;
import java.util.Objects;
import java.util.Optional;
/**
* Estimates memory usage for the given data frame analytics spec.
* Redirects to a different node if the current node is *not* an ML node.
*/
public class TransportEstimateMemoryUsageAction
extends HandledTransportAction<EstimateMemoryUsageAction.Request, EstimateMemoryUsageAction.Response> {
private final TransportService transportService;
private final ClusterService clusterService;
private final NodeClient client;
private final MemoryUsageEstimationProcessManager processManager;
@Inject
public TransportEstimateMemoryUsageAction(TransportService transportService,
ActionFilters actionFilters,
ClusterService clusterService,
NodeClient client,
MemoryUsageEstimationProcessManager processManager) {
super(EstimateMemoryUsageAction.NAME, transportService, actionFilters, EstimateMemoryUsageAction.Request::new);
this.transportService = transportService;
this.clusterService = Objects.requireNonNull(clusterService);
this.client = Objects.requireNonNull(client);
this.processManager = Objects.requireNonNull(processManager);
}
@Override
protected void doExecute(Task task,
EstimateMemoryUsageAction.Request request,
ActionListener<EstimateMemoryUsageAction.Response> listener) {
DiscoveryNode localNode = clusterService.localNode();
if (MachineLearning.isMlNode(localNode)) {
doEstimateMemoryUsage(createTaskIdForMemoryEstimation(task), request, listener);
} else {
redirectToMlNode(request, listener);
}
}
/**
* Creates a unique task id for the memory estimation process. The id is useful for logging.
*/
private static String createTaskIdForMemoryEstimation(Task task) {
return "memory_usage_estimation_" + task.getId();
}
/**
* Performs memory usage estimation.
* Memory usage estimation spawns an ML C++ process, which is only available on ML nodes. That is why this method may only be
* called on an ML node.
*/
private void doEstimateMemoryUsage(String taskId,
EstimateMemoryUsageAction.Request request,
ActionListener<EstimateMemoryUsageAction.Response> listener) {
DataFrameDataExtractorFactory.createForSourceIndices(
client,
taskId,
request.getConfig(),
ActionListener.wrap(
dataExtractorFactory -> {
processManager.runJobAsync(
taskId,
request.getConfig(),
dataExtractorFactory,
ActionListener.wrap(
result -> listener.onResponse(
new EstimateMemoryUsageAction.Response(
result.getExpectedMemoryUsageWithOnePartition(), result.getExpectedMemoryUsageWithMaxPartitions())),
listener::onFailure
)
);
},
listener::onFailure
)
);
}
/**
* Finds the first available ML node in the cluster and redirects the request to this node.
*/
private void redirectToMlNode(EstimateMemoryUsageAction.Request request,
ActionListener<EstimateMemoryUsageAction.Response> listener) {
Optional<DiscoveryNode> node = findMlNode(clusterService.state());
if (node.isPresent()) {
transportService.sendRequest(
node.get(), actionName, request, new ActionListenerResponseHandler<>(listener, EstimateMemoryUsageAction.Response::new));
} else {
listener.onFailure(ExceptionsHelper.badRequestException("No ML node to run on"));
}
}
/**
* Finds the first available ML node in the cluster state.
*/
private static Optional<DiscoveryNode> findMlNode(ClusterState clusterState) {
for (DiscoveryNode node : clusterState.getNodes()) {
if (MachineLearning.isMlNode(node)) {
return Optional.of(node);
}
}
return Optional.empty();
}
}

View File

@ -210,7 +210,7 @@ public class DataFrameAnalyticsManager {
// TODO This could fail with errors. In that case we get stuck with the copied index.
// We could delete the index in case of failure or we could try building the factory before reindexing
// to catch the error early on.
DataFrameDataExtractorFactory.create(client, config, isTaskRestarting, dataExtractorFactoryListener);
DataFrameDataExtractorFactory.createForDestinationIndex(client, config, isTaskRestarting, dataExtractorFactoryListener);
}
public void stop(DataFrameAnalyticsTask task) {

View File

@ -29,6 +29,7 @@ import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
@ -37,15 +38,15 @@ public class DataFrameDataExtractorFactory {
private final Client client;
private final String analyticsId;
private final String index;
private final List<String> indices;
private final ExtractedFields extractedFields;
private final Map<String, String> headers;
private DataFrameDataExtractorFactory(Client client, String analyticsId, String index, ExtractedFields extractedFields,
private DataFrameDataExtractorFactory(Client client, String analyticsId, List<String> indices, ExtractedFields extractedFields,
Map<String, String> headers) {
this.client = Objects.requireNonNull(client);
this.analyticsId = Objects.requireNonNull(analyticsId);
this.index = Objects.requireNonNull(index);
this.indices = Objects.requireNonNull(indices);
this.extractedFields = Objects.requireNonNull(extractedFields);
this.headers = headers;
}
@ -54,7 +55,7 @@ public class DataFrameDataExtractorFactory {
DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(
analyticsId,
extractedFields,
Arrays.asList(index),
indices,
allExtractedFieldsExistQuery(),
1000,
headers,
@ -71,6 +72,34 @@ public class DataFrameDataExtractorFactory {
return query;
}
/**
* Validate and create a new extractor factory
*
* The source index must exist and contain at least 1 compatible field or validations will fail.
*
* @param client ES Client used to make calls against the cluster
* @param taskId Id of the task for which the extractor factory is created
* @param config The config from which to create the extractor factory
* @param listener The listener to notify on creation or failure
*/
public static void createForSourceIndices(Client client,
String taskId,
DataFrameAnalyticsConfig config,
ActionListener<DataFrameDataExtractorFactory> listener) {
validateIndexAndExtractFields(
client,
config.getSource().getIndex(),
config,
null,
false,
ActionListener.wrap(
extractedFields -> listener.onResponse(
new DataFrameDataExtractorFactory(
client, taskId, Arrays.asList(config.getSource().getIndex()), extractedFields, config.getHeaders())),
listener::onFailure
)
);
}
/**
* Validate and create a new extractor factory
*
@ -81,15 +110,23 @@ public class DataFrameDataExtractorFactory {
* @param isTaskRestarting Whether the task is restarting
* @param listener The listener to notify on creation or failure
*/
public static void create(Client client,
DataFrameAnalyticsConfig config,
boolean isTaskRestarting,
ActionListener<DataFrameDataExtractorFactory> listener) {
validateIndexAndExtractFields(client, new String[] {config.getDest().getIndex()}, config, isTaskRestarting,
ActionListener.wrap(extractedFields -> listener.onResponse(new DataFrameDataExtractorFactory(
client, config.getId(), config.getDest().getIndex(), extractedFields, config.getHeaders())),
public static void createForDestinationIndex(Client client,
DataFrameAnalyticsConfig config,
boolean isTaskRestarting,
ActionListener<DataFrameDataExtractorFactory> listener) {
validateIndexAndExtractFields(
client,
new String[] {config.getDest().getIndex()},
config,
config.getDest().getResultsField(),
isTaskRestarting,
ActionListener.wrap(
extractedFields -> listener.onResponse(
new DataFrameDataExtractorFactory(
client, config.getId(), Arrays.asList(config.getDest().getIndex()), extractedFields, config.getHeaders())),
listener::onFailure
));
)
);
}
/**
@ -102,26 +139,36 @@ public class DataFrameDataExtractorFactory {
public static void validateConfigAndSourceIndex(Client client,
DataFrameAnalyticsConfig config,
ActionListener<DataFrameAnalyticsConfig> listener) {
validateIndexAndExtractFields(client, config.getSource().getIndex(), config, false, ActionListener.wrap(
validateIndexAndExtractFields(
client,
config.getSource().getIndex(),
config,
config.getDest().getResultsField(),
false,
ActionListener.wrap(
fields -> {
config.getSource().getParsedQuery(); // validate query is acceptable
listener.onResponse(config);
},
listener::onFailure
));
)
);
}
private static void validateIndexAndExtractFields(Client client,
String[] index,
DataFrameAnalyticsConfig config,
String resultsField,
boolean isTaskRestarting,
ActionListener<ExtractedFields> listener) {
AtomicInteger docValueFieldsLimitHolder = new AtomicInteger();
// Step 3. Extract fields (if possible) and notify listener
ActionListener<FieldCapabilitiesResponse> fieldCapabilitiesHandler = ActionListener.wrap(
fieldCapabilitiesResponse -> listener.onResponse(new ExtractedFieldsDetector(index, config, isTaskRestarting,
docValueFieldsLimitHolder.get(), fieldCapabilitiesResponse).detect()),
fieldCapabilitiesResponse -> listener.onResponse(
new ExtractedFieldsDetector(
index, config, resultsField, isTaskRestarting, docValueFieldsLimitHolder.get(), fieldCapabilitiesResponse)
.detect()),
listener::onFailure
);
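As a usage note (a sketch; the surrounding wiring is elided), the factory now exposes two entry points that differ in which indices and results field they validate:

// Memory estimation runs against the source indices; no destination index exists yet,
// so no results field is checked (null is passed through to the detector).
DataFrameDataExtractorFactory.createForSourceIndices(client, taskId, config, listener);

// The analytics run proper extracts from the reindexed destination index and
// validates that its results field is not already present (unless restarting).
DataFrameDataExtractorFactory.createForDestinationIndex(client, config, isTaskRestarting, listener);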

View File

@ -60,14 +60,16 @@ public class ExtractedFieldsDetector {
private final String[] index;
private final DataFrameAnalyticsConfig config;
private final String resultsField;
private final boolean isTaskRestarting;
private final int docValueFieldsLimit;
private final FieldCapabilitiesResponse fieldCapabilitiesResponse;
ExtractedFieldsDetector(String[] index, DataFrameAnalyticsConfig config, boolean isTaskRestarting, int docValueFieldsLimit,
FieldCapabilitiesResponse fieldCapabilitiesResponse) {
ExtractedFieldsDetector(String[] index, DataFrameAnalyticsConfig config, String resultsField, boolean isTaskRestarting,
int docValueFieldsLimit, FieldCapabilitiesResponse fieldCapabilitiesResponse) {
this.index = Objects.requireNonNull(index);
this.config = Objects.requireNonNull(config);
this.resultsField = resultsField;
this.isTaskRestarting = isTaskRestarting;
this.docValueFieldsLimit = docValueFieldsLimit;
this.fieldCapabilitiesResponse = Objects.requireNonNull(fieldCapabilitiesResponse);
@ -76,12 +78,7 @@ public class ExtractedFieldsDetector {
public ExtractedFields detect() {
Set<String> fields = new HashSet<>(fieldCapabilitiesResponse.get().keySet());
fields.removeAll(IGNORE_FIELDS);
checkResultsFieldIsNotPresent();
// Ignore fields under the results object
fields.removeIf(field -> field.startsWith(config.getDest().getResultsField() + "."));
removeFieldsUnderResultsField(fields);
includeAndExcludeFields(fields);
removeFieldsWithIncompatibleTypes(fields);
checkRequiredFieldsArePresent(fields);
@ -105,17 +102,28 @@ public class ExtractedFieldsDetector {
return extractedFields;
}
private void removeFieldsUnderResultsField(Set<String> fields) {
if (resultsField == null) {
return;
}
checkResultsFieldIsNotPresent();
// Ignore fields under the results object
fields.removeIf(field -> field.startsWith(resultsField + "."));
}
private void checkResultsFieldIsNotPresent() {
// If the task is restarting, we do not mind the index containing the results field; we will overwrite all docs
if (isTaskRestarting) {
return;
}
Map<String, FieldCapabilities> indexToFieldCaps = fieldCapabilitiesResponse.getField(config.getDest().getResultsField());
Map<String, FieldCapabilities> indexToFieldCaps = fieldCapabilitiesResponse.getField(resultsField);
if (indexToFieldCaps != null && indexToFieldCaps.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("A field that matches the {}.{} [{}] already exists;" +
" please set a different {}", DataFrameAnalyticsConfig.DEST.getPreferredName(),
DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName(), config.getDest().getResultsField(),
throw ExceptionsHelper.badRequestException(
"A field that matches the {}.{} [{}] already exists; please set a different {}",
DataFrameAnalyticsConfig.DEST.getPreferredName(),
DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName(),
resultsField,
DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName());
}
}

View File

@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.xpack.ml.process.AbstractNativeProcess;
import org.elasticsearch.xpack.ml.process.ProcessResultsParser;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Path;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
abstract class AbstractNativeAnalyticsProcess<Result> extends AbstractNativeProcess implements AnalyticsProcess<Result> {
private final String name;
private final ProcessResultsParser<Result> resultsParser;
protected AbstractNativeAnalyticsProcess(String name, ConstructingObjectParser<Result, Void> resultParser, String jobId,
InputStream logStream, OutputStream processInStream,
InputStream processOutStream, OutputStream processRestoreStream, int numberOfFields,
List<Path> filesToDelete, Consumer<String> onProcessCrash) {
super(jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields, filesToDelete, onProcessCrash);
this.name = Objects.requireNonNull(name);
this.resultsParser = new ProcessResultsParser<>(Objects.requireNonNull(resultParser));
}
@Override
public String getName() {
return name;
}
@Override
public void persistState() {
// Nothing to persist
}
@Override
public void writeEndOfDataMessage() throws IOException {
new AnalyticsControlMessageWriter(recordWriter(), numberOfFields()).writeEndOfData();
}
@Override
public Iterator<Result> readAnalyticsResults() {
return resultsParser.parseResults(processOutStream());
}
}
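A hedged sketch of what a concrete subclass might look like (the actual native process classes in this PR are not shown in this excerpt, and the `MemoryUsageEstimationResult.PARSER` field name is an assumption):

// Illustrative subclass: the constructor supplies a process name and a result parser,
// and the base class handles state persistence, end-of-data, and result-stream parsing.
class SketchMemoryUsageEstimationProcess extends AbstractNativeAnalyticsProcess<MemoryUsageEstimationResult> {
    SketchMemoryUsageEstimationProcess(String jobId, InputStream logStream, OutputStream processInStream,
                                       InputStream processOutStream, OutputStream processRestoreStream,
                                       int numberOfFields, List<Path> filesToDelete, Consumer<String> onProcessCrash) {
        super("memory_usage_estimation", MemoryUsageEstimationResult.PARSER, jobId, logStream, processInStream,
            processOutStream, processRestoreStream, numberOfFields, filesToDelete, onProcessCrash);
    }
}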

View File

@ -9,7 +9,6 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.env.Environment;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;
@ -21,6 +20,7 @@ import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
public class AnalyticsBuilder {
@ -29,38 +29,49 @@ public class AnalyticsBuilder {
private static final String LENGTH_ENCODED_INPUT_ARG = "--lengthEncodedInput";
private static final String CONFIG_ARG = "--config=";
private static final String MEMORY_USAGE_ESTIMATION_ONLY_ARG = "--memoryUsageEstimationOnly";
private final Environment env;
private final Supplier<Path> tempDirPathSupplier;
private final NativeController nativeController;
private final ProcessPipes processPipes;
private final AnalyticsProcessConfig config;
private final List<Path> filesToDelete;
private boolean performMemoryUsageEstimationOnly;
public AnalyticsBuilder(Environment env, NativeController nativeController, ProcessPipes processPipes, AnalyticsProcessConfig config,
List<Path> filesToDelete) {
this.env = Objects.requireNonNull(env);
public AnalyticsBuilder(Supplier<Path> tempDirPathSupplier, NativeController nativeController,
ProcessPipes processPipes, AnalyticsProcessConfig config, List<Path> filesToDelete) {
this.tempDirPathSupplier = Objects.requireNonNull(tempDirPathSupplier);
this.nativeController = Objects.requireNonNull(nativeController);
this.processPipes = Objects.requireNonNull(processPipes);
this.config = Objects.requireNonNull(config);
this.filesToDelete = Objects.requireNonNull(filesToDelete);
}
public AnalyticsBuilder performMemoryUsageEstimationOnly() {
this.performMemoryUsageEstimationOnly = true;
return this;
}
public void build() throws IOException {
List<String> command = buildAnalyticsCommand();
processPipes.addArgs(command);
nativeController.startProcess(command);
}
List<String> buildAnalyticsCommand() throws IOException {
private List<String> buildAnalyticsCommand() throws IOException {
List<String> command = new ArrayList<>();
command.add(ANALYTICS_PATH);
command.add(LENGTH_ENCODED_INPUT_ARG);
addConfigFile(command);
if (performMemoryUsageEstimationOnly) {
command.add(MEMORY_USAGE_ESTIMATION_ONLY_ARG);
}
return command;
}
private void addConfigFile(List<String> command) throws IOException {
Path configFile = Files.createTempFile(env.tmpFile(), "analysis", ".conf");
Path tempDir = tempDirPathSupplier.get();
Path configFile = Files.createTempFile(tempDir, "analysis", ".conf");
filesToDelete.add(configFile);
try (OutputStreamWriter osw = new OutputStreamWriter(Files.newOutputStream(configFile), StandardCharsets.UTF_8);
XContentBuilder jsonBuilder = JsonXContent.contentBuilder()) {

View File

@ -10,7 +10,7 @@ import org.elasticsearch.xpack.ml.process.NativeProcess;
import java.io.IOException;
import java.util.Iterator;
public interface AnalyticsProcess extends NativeProcess {
public interface AnalyticsProcess<ProcessResult> extends NativeProcess {
/**
* Writes a control message that informs the process
@ -22,7 +22,7 @@ public interface AnalyticsProcess extends NativeProcess {
/**
* @return stream of data frame analytics results.
*/
Iterator<AnalyticsResult> readAnalyticsResults();
Iterator<ProcessResult> readAnalyticsResults();
/**
* Read anything left in the stream before

View File

@ -8,7 +8,7 @@ package org.elasticsearch.xpack.ml.dataframe.process;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
public interface AnalyticsProcessFactory {
public interface AnalyticsProcessFactory<ProcessResult> {
/**
* Create an implementation of {@link AnalyticsProcess}
@ -19,6 +19,6 @@ public interface AnalyticsProcessFactory {
* @param onProcessCrash Callback to execute if the process stops unexpectedly
* @return The process
*/
AnalyticsProcess createAnalyticsProcess(String jobId, AnalyticsProcessConfig analyticsProcessConfig, ExecutorService executorService,
Consumer<String> onProcessCrash);
AnalyticsProcess<ProcessResult> createAnalyticsProcess(String jobId, AnalyticsProcessConfig analyticsProcessConfig,
ExecutorService executorService, Consumer<String> onProcessCrash);
}
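Since the interface keeps a single abstract method after gaining its type parameter, stub implementations remain plain lambdas; for example, the no-op fallback registered in MachineLearning.java earlier in this diff:

AnalyticsProcessFactory<MemoryUsageEstimationResult> memoryEstimationProcessFactory =
    (jobId, analyticsProcessConfig, executorService, onProcessCrash) -> null;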

View File

@ -21,6 +21,7 @@ import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import java.io.IOException;
import java.util.List;
@ -39,10 +40,12 @@ public class AnalyticsProcessManager {
private final Client client;
private final ThreadPool threadPool;
private final AnalyticsProcessFactory processFactory;
private final AnalyticsProcessFactory<AnalyticsResult> processFactory;
private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
public AnalyticsProcessManager(Client client, ThreadPool threadPool, AnalyticsProcessFactory analyticsProcessFactory) {
public AnalyticsProcessManager(Client client,
ThreadPool threadPool,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory) {
this.client = Objects.requireNonNull(client);
this.threadPool = Objects.requireNonNull(threadPool);
this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
@ -83,7 +86,8 @@ public class AnalyticsProcessManager {
}
private void processData(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor,
AnalyticsProcess process, AnalyticsResultProcessor resultProcessor, Consumer<Exception> finishHandler) {
AnalyticsProcess<AnalyticsResult> process, AnalyticsResultProcessor resultProcessor,
Consumer<Exception> finishHandler) {
try {
writeHeaderRecord(dataExtractor, process);
@ -118,7 +122,7 @@ public class AnalyticsProcessManager {
}
}
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException {
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process) throws IOException {
// The extra fields are for the doc hash and the control field (should be an empty string)
String[] record = new String[dataExtractor.getFieldNames().size() + 2];
// The value of the control field should be an empty string for data frame rows
@ -139,7 +143,7 @@ public class AnalyticsProcessManager {
}
}
private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException {
private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process) throws IOException {
List<String> fieldNames = dataExtractor.getFieldNames();
// We add 2 extra fields, both named dot:
@ -155,9 +159,9 @@ public class AnalyticsProcessManager {
process.writeRecord(headerRecord);
}
private AnalyticsProcess createProcess(DataFrameAnalyticsTask task, AnalyticsProcessConfig analyticsProcessConfig) {
private AnalyticsProcess<AnalyticsResult> createProcess(DataFrameAnalyticsTask task, AnalyticsProcessConfig analyticsProcessConfig) {
ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
AnalyticsProcess process = processFactory.createAnalyticsProcess(task.getParams().getId(), analyticsProcessConfig,
AnalyticsProcess<AnalyticsResult> process = processFactory.createAnalyticsProcess(task.getParams().getId(), analyticsProcessConfig,
executorService, onProcessCrash(task));
if (process.isProcessAlive() == false) {
throw ExceptionsHelper.serverError("Failed to start data frame analytics process");
@ -215,7 +219,7 @@ public class AnalyticsProcessManager {
class ProcessContext {
private final String id;
private volatile AnalyticsProcess process;
private volatile AnalyticsProcess<AnalyticsResult> process;
private volatile DataFrameDataExtractor dataExtractor;
private volatile AnalyticsResultProcessor resultProcessor;
private final AtomicInteger progressPercent = new AtomicInteger(0);

View File

@ -9,6 +9,7 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import java.util.Iterator;
@ -53,7 +54,7 @@ public class AnalyticsResultProcessor {
}
}
public void process(AnalyticsProcess process) {
public void process(AnalyticsProcess<AnalyticsResult> process) {
// TODO When java 9 features can be used, we will not need the local variable here
try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) {
Iterator<AnalyticsResult> iterator = process.readAnalyticsResults();

View File

@ -0,0 +1,143 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import java.io.IOException;
import java.util.Iterator;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
public class MemoryUsageEstimationProcessManager {
private static final Logger LOGGER = LogManager.getLogger(MemoryUsageEstimationProcessManager.class);
private final ExecutorService executorServiceForJob;
private final ExecutorService executorServiceForProcess;
private final AnalyticsProcessFactory<MemoryUsageEstimationResult> processFactory;
public MemoryUsageEstimationProcessManager(ExecutorService executorServiceForJob,
ExecutorService executorServiceForProcess,
AnalyticsProcessFactory<MemoryUsageEstimationResult> processFactory) {
this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
this.processFactory = Objects.requireNonNull(processFactory);
}
public void runJobAsync(String jobId,
DataFrameAnalyticsConfig config,
DataFrameDataExtractorFactory dataExtractorFactory,
ActionListener<MemoryUsageEstimationResult> listener) {
executorServiceForJob.execute(() -> {
try {
MemoryUsageEstimationResult result = runJob(jobId, config, dataExtractorFactory);
listener.onResponse(result);
} catch (Exception e) {
listener.onFailure(e);
}
});
}
private MemoryUsageEstimationResult runJob(String jobId,
DataFrameAnalyticsConfig config,
DataFrameDataExtractorFactory dataExtractorFactory) {
DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(false);
DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary();
Set<String> categoricalFields = dataExtractor.getCategoricalFields();
if (dataSummary.rows == 0) {
return new MemoryUsageEstimationResult(ByteSizeValue.ZERO, ByteSizeValue.ZERO);
}
AnalyticsProcessConfig processConfig =
new AnalyticsProcessConfig(
dataSummary.rows,
dataSummary.cols,
DataFrameAnalyticsConfig.MIN_MODEL_MEMORY_LIMIT,
1,
"",
categoricalFields,
config.getAnalysis());
ProcessHolder processHolder = new ProcessHolder();
AnalyticsProcess<MemoryUsageEstimationResult> process =
processFactory.createAnalyticsProcess(
jobId,
processConfig,
executorServiceForProcess,
onProcessCrash(jobId, processHolder));
processHolder.process = process;
if (process.isProcessAlive() == false) {
String errorMsg = new ParameterizedMessage("[{}] Error while starting process", jobId).getFormattedMessage();
throw ExceptionsHelper.serverError(errorMsg);
}
try {
return readResult(jobId, process);
} catch (Exception e) {
String errorMsg =
new ParameterizedMessage("[{}] Error while processing result [{}]", jobId, e.getMessage()).getFormattedMessage();
throw ExceptionsHelper.serverError(errorMsg, e);
} finally {
process.consumeAndCloseOutputStream();
try {
LOGGER.info("[{}] Closing process", jobId);
process.close();
LOGGER.info("[{}] Closed process", jobId);
} catch (Exception e) {
String errorMsg =
new ParameterizedMessage("[{}] Error while closing process [{}]", jobId, e.getMessage()).getFormattedMessage();
throw ExceptionsHelper.serverError(errorMsg, e);
}
}
}
private static class ProcessHolder {
volatile AnalyticsProcess<MemoryUsageEstimationResult> process;
}
private static Consumer<String> onProcessCrash(String jobId, ProcessHolder processHolder) {
return reason -> {
AnalyticsProcess<MemoryUsageEstimationResult> process = processHolder.process;
if (process == null) {
LOGGER.error(new ParameterizedMessage("[{}] Process does not exist", jobId));
return;
}
try {
process.kill();
} catch (IOException e) {
LOGGER.error(new ParameterizedMessage("[{}] Failed to kill process", jobId), e);
}
};
}
/**
* Extracts {@link MemoryUsageEstimationResult} from the process's output.
*/
private static MemoryUsageEstimationResult readResult(String jobId, AnalyticsProcess<MemoryUsageEstimationResult> process) {
Iterator<MemoryUsageEstimationResult> iterator = process.readAnalyticsResults();
if (iterator.hasNext() == false) {
String errorMsg =
new ParameterizedMessage("[{}] Memory usage estimation process returned no results", jobId).getFormattedMessage();
throw ExceptionsHelper.serverError(errorMsg);
}
MemoryUsageEstimationResult result = iterator.next();
if (iterator.hasNext()) {
String errorMsg =
new ParameterizedMessage("[{}] Memory usage estimation process returned more than one result", jobId).getFormattedMessage();
throw ExceptionsHelper.serverError(errorMsg);
}
return result;
}
}
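For reference, a hedged sketch of how a caller (for example, the transport action behind the new endpoint) might drive this manager. The `processManager`, `config`, and `dataExtractorFactory` instances, the job id, and the logger are assumed wiring; `ActionListener.wrap` is the standard Elasticsearch helper:

[source,java]
----
processManager.runJobAsync(
    "memory-estimation-demo",   // hypothetical jobId
    config,                     // the DataFrameAnalyticsConfig to estimate for
    dataExtractorFactory,
    ActionListener.wrap(
        result -> LOGGER.info("estimate with one partition: {}, with max partitions: {}",
            result.getExpectedMemoryUsageWithOnePartition(),
            result.getExpectedMemoryUsageWithMaxPartitions()),
        e -> LOGGER.error("memory usage estimation failed", e)));
----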

View File

@ -5,46 +5,22 @@
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.xpack.ml.process.AbstractNativeProcess;
import org.elasticsearch.xpack.ml.process.ProcessResultsParser;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Path;
import java.util.Iterator;
import java.util.List;
import java.util.function.Consumer;
public class NativeAnalyticsProcess extends AbstractNativeProcess implements AnalyticsProcess {
public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess<AnalyticsResult> {
private static final String NAME = "analytics";
private final ProcessResultsParser<AnalyticsResult> resultsParser = new ProcessResultsParser<>(AnalyticsResult.PARSER);
protected NativeAnalyticsProcess(String jobId, InputStream logStream, OutputStream processInStream, InputStream processOutStream,
OutputStream processRestoreStream, int numberOfFields, List<Path> filesToDelete,
Consumer<String> onProcessCrash) {
super(jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields, filesToDelete, onProcessCrash);
}
@Override
public String getName() {
return NAME;
}
@Override
public void persistState() {
// Nothing to persist
}
@Override
public void writeEndOfDataMessage() throws IOException {
new AnalyticsControlMessageWriter(recordWriter(), numberOfFields()).writeEndOfData();
}
@Override
public Iterator<AnalyticsResult> readAnalyticsResults() {
return resultsParser.parseResults(processOutStream());
protected NativeAnalyticsProcess(String jobId, InputStream logStream, OutputStream processInStream,
InputStream processOutStream, OutputStream processRestoreStream, int numberOfFields,
List<Path> filesToDelete, Consumer<String> onProcessCrash) {
super(NAME, AnalyticsResult.PARSER, jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields,
filesToDelete, onProcessCrash);
}
}
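`AbstractNativeAnalyticsProcess` itself is not part of this hunk; judging from the two `super(...)` calls in its subclasses, it factors out the process name, the `ProcessResultsParser`, and the parser-driven `readAnalyticsResults`. A hedged reconstruction (anything beyond what the subclasses reference is an assumption):

[source,java]
----
// Assumed shape, reconstructed from the subclass constructors above.
abstract class AbstractNativeAnalyticsProcess<Result> extends AbstractNativeProcess
        implements AnalyticsProcess<Result> {

    private final String name;
    private final ProcessResultsParser<Result> resultsParser;

    protected AbstractNativeAnalyticsProcess(String name, ConstructingObjectParser<Result, Void> resultParser,
                                             String jobId, InputStream logStream, OutputStream processInStream,
                                             InputStream processOutStream, OutputStream processRestoreStream,
                                             int numberOfFields, List<Path> filesToDelete,
                                             Consumer<String> onProcessCrash) {
        super(jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields,
            filesToDelete, onProcessCrash);
        this.name = name;
        this.resultsParser = new ProcessResultsParser<>(resultParser);
    }

    @Override
    public String getName() {
        return name;
    }

    @Override
    public Iterator<Result> readAnalyticsResults() {
        return resultsParser.parseResults(processOutStream());
    }
}
----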

View File

@ -14,6 +14,7 @@ import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.env.Environment;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;
import org.elasticsearch.xpack.ml.utils.NamedPipeHelper;
@ -27,7 +28,7 @@ import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory {
public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory<AnalyticsResult> {
private static final Logger LOGGER = LogManager.getLogger(NativeAnalyticsProcessFactory.class);
@ -50,7 +51,7 @@ public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory {
}
@Override
public AnalyticsProcess createAnalyticsProcess(String jobId, AnalyticsProcessConfig analyticsProcessConfig,
public NativeAnalyticsProcess createAnalyticsProcess(String jobId, AnalyticsProcessConfig analyticsProcessConfig,
ExecutorService executorService, Consumer<String> onProcessCrash) {
List<Path> filesToDelete = new ArrayList<>();
ProcessPipes processPipes = new ProcessPipes(env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, jobId,
@ -80,8 +81,8 @@ public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory {
private void createNativeProcess(String jobId, AnalyticsProcessConfig analyticsProcessConfig, List<Path> filesToDelete,
ProcessPipes processPipes) {
AnalyticsBuilder analyticsBuilder = new AnalyticsBuilder(env, nativeController, processPipes, analyticsProcessConfig,
filesToDelete);
AnalyticsBuilder analyticsBuilder =
new AnalyticsBuilder(env::tmpFile, nativeController, processPipes, analyticsProcessConfig, filesToDelete);
try {
analyticsBuilder.build();
processPipes.connectStreams(processConnectTimeout);

View File

@ -0,0 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Path;
import java.util.List;
import java.util.function.Consumer;
public class NativeMemoryUsageEstimationProcess extends AbstractNativeAnalyticsProcess<MemoryUsageEstimationResult> {
private static final String NAME = "memory_usage_estimation";
protected NativeMemoryUsageEstimationProcess(String jobId, InputStream logStream,
OutputStream processInStream, InputStream processOutStream,
OutputStream processRestoreStream, int numberOfFields, List<Path> filesToDelete,
Consumer<String> onProcessCrash) {
super(NAME, MemoryUsageEstimationResult.PARSER, jobId, logStream, processInStream, processOutStream, processRestoreStream,
numberOfFields, filesToDelete, onProcessCrash);
}
}

View File

@ -0,0 +1,103 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.env.Environment;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;
import org.elasticsearch.xpack.ml.utils.NamedPipeHelper;
import java.io.IOException;
import java.nio.file.Path;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
public class NativeMemoryUsageEstimationProcessFactory implements AnalyticsProcessFactory<MemoryUsageEstimationResult> {
private static final Logger LOGGER = LogManager.getLogger(NativeMemoryUsageEstimationProcessFactory.class);
private static final NamedPipeHelper NAMED_PIPE_HELPER = new NamedPipeHelper();
private final Environment env;
private final NativeController nativeController;
private volatile Duration processConnectTimeout;
public NativeMemoryUsageEstimationProcessFactory(Environment env, NativeController nativeController, ClusterService clusterService) {
this.env = Objects.requireNonNull(env);
this.nativeController = Objects.requireNonNull(nativeController);
setProcessConnectTimeout(MachineLearning.PROCESS_CONNECT_TIMEOUT.get(env.settings()));
clusterService.getClusterSettings().addSettingsUpdateConsumer(
MachineLearning.PROCESS_CONNECT_TIMEOUT, this::setProcessConnectTimeout);
}
void setProcessConnectTimeout(TimeValue processConnectTimeout) {
this.processConnectTimeout = Duration.ofMillis(processConnectTimeout.getMillis());
}
@Override
public NativeMemoryUsageEstimationProcess createAnalyticsProcess(
String jobId,
AnalyticsProcessConfig analyticsProcessConfig,
ExecutorService executorService,
Consumer<String> onProcessCrash) {
List<Path> filesToDelete = new ArrayList<>();
ProcessPipes processPipes = new ProcessPipes(
env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, jobId, true, false, false, true, false, false);
createNativeProcess(jobId, analyticsProcessConfig, filesToDelete, processPipes);
NativeMemoryUsageEstimationProcess process = new NativeMemoryUsageEstimationProcess(
jobId,
processPipes.getLogStream().get(),
// Memory estimation process does not use the input pipe, hence null.
null,
processPipes.getProcessOutStream().get(),
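// The restore stream is likewise unused: the estimation process restores no state, hence null.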
null,
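// Zero input fields: nothing is ever written to a process that has no input pipe.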
0,
filesToDelete,
onProcessCrash);
try {
process.start(executorService);
return process;
} catch (EsRejectedExecutionException e) {
try {
IOUtils.close(process);
} catch (IOException ioe) {
LOGGER.error("Can't close data frame analytics memory usage estimation process", ioe);
}
throw e;
}
}
private void createNativeProcess(String jobId, AnalyticsProcessConfig analyticsProcessConfig, List<Path> filesToDelete,
ProcessPipes processPipes) {
AnalyticsBuilder analyticsBuilder =
new AnalyticsBuilder(env::tmpFile, nativeController, processPipes, analyticsProcessConfig, filesToDelete)
.performMemoryUsageEstimationOnly();
try {
analyticsBuilder.build();
processPipes.connectStreams(processConnectTimeout);
} catch (IOException e) {
String msg = "Failed to launch data frame analytics memory usage estimation process for job " + jobId;
LOGGER.error(msg);
throw ExceptionsHelper.serverError(msg, e);
}
}
}

View File

@ -3,29 +3,30 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process;
package org.elasticsearch.xpack.ml.dataframe.process.results;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class AnalyticsResult implements ToXContentObject {
public static final ParseField TYPE = new ParseField("analytics_result");
public static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent");
static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(),
public static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(),
a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1]));
static {
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), PROGRESS_PERCENT);
PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT);
}
private final RowResults rowResults;

View File

@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.results;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
public class MemoryUsageEstimationResult implements ToXContentObject {
public static final ParseField TYPE = new ParseField("memory_usage_estimation_result");
public static final ParseField EXPECTED_MEMORY_USAGE_WITH_ONE_PARTITION = new ParseField("expected_memory_usage_with_one_partition");
public static final ParseField EXPECTED_MEMORY_USAGE_WITH_MAX_PARTITIONS = new ParseField("expected_memory_usage_with_max_partitions");
public static final ConstructingObjectParser<MemoryUsageEstimationResult, Void> PARSER =
new ConstructingObjectParser<>(
TYPE.getPreferredName(),
true,
args -> new MemoryUsageEstimationResult((ByteSizeValue) args[0], (ByteSizeValue) args[1]));
static {
PARSER.declareField(
optionalConstructorArg(),
(p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), EXPECTED_MEMORY_USAGE_WITH_ONE_PARTITION.getPreferredName()),
EXPECTED_MEMORY_USAGE_WITH_ONE_PARTITION,
ObjectParser.ValueType.VALUE);
PARSER.declareField(
optionalConstructorArg(),
(p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), EXPECTED_MEMORY_USAGE_WITH_MAX_PARTITIONS.getPreferredName()),
EXPECTED_MEMORY_USAGE_WITH_MAX_PARTITIONS,
ObjectParser.ValueType.VALUE);
}
private final ByteSizeValue expectedMemoryUsageWithOnePartition;
private final ByteSizeValue expectedMemoryUsageWithMaxPartitions;
public MemoryUsageEstimationResult(@Nullable ByteSizeValue expectedMemoryUsageWithOnePartition,
@Nullable ByteSizeValue expectedMemoryUsageWithMaxPartitions) {
this.expectedMemoryUsageWithOnePartition = expectedMemoryUsageWithOnePartition;
this.expectedMemoryUsageWithMaxPartitions = expectedMemoryUsageWithMaxPartitions;
}
public ByteSizeValue getExpectedMemoryUsageWithOnePartition() {
return expectedMemoryUsageWithOnePartition;
}
public ByteSizeValue getExpectedMemoryUsageWithMaxPartitions() {
return expectedMemoryUsageWithMaxPartitions;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (expectedMemoryUsageWithOnePartition != null) {
builder.field(
EXPECTED_MEMORY_USAGE_WITH_ONE_PARTITION.getPreferredName(), expectedMemoryUsageWithOnePartition.getStringRep());
}
if (expectedMemoryUsageWithMaxPartitions != null) {
builder.field(
EXPECTED_MEMORY_USAGE_WITH_MAX_PARTITIONS.getPreferredName(), expectedMemoryUsageWithMaxPartitions.getStringRep());
}
builder.endObject();
return builder;
}
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || getClass() != other.getClass()) {
return false;
}
MemoryUsageEstimationResult that = (MemoryUsageEstimationResult) other;
return Objects.equals(expectedMemoryUsageWithOnePartition, that.expectedMemoryUsageWithOnePartition)
&& Objects.equals(expectedMemoryUsageWithMaxPartitions, that.expectedMemoryUsageWithMaxPartitions);
}
@Override
public int hashCode() {
return Objects.hash(expectedMemoryUsageWithOnePartition, expectedMemoryUsageWithMaxPartitions);
}
}
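A hedged round-trip sketch for this class, using only the `PARSER` declared above plus standard `XContent` helpers; the exact strings emitted depend on `ByteSizeValue.getStringRep`:

[source,java]
----
MemoryUsageEstimationResult original =
    new MemoryUsageEstimationResult(new ByteSizeValue(2048), new ByteSizeValue(1024));

String json = Strings.toString(original);  // serialize via toXContent

try (XContentParser parser = JsonXContent.jsonXContent.createParser(
        NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, json)) {
    MemoryUsageEstimationResult reparsed = MemoryUsageEstimationResult.PARSER.apply(parser, null);
    assert original.equals(reparsed);
}
----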

View File

@ -14,6 +14,8 @@ import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
public class RowResults implements ToXContentObject {
public static final ParseField TYPE = new ParseField("row_results");
@ -25,8 +27,8 @@ public class RowResults implements ToXContentObject {
a -> new RowResults((Integer) a[0], (Map<String, Object>) a[1]));
static {
PARSER.declareInt(ConstructingObjectParser.constructorArg(), CHECKSUM);
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, context) -> p.map(), RESULTS);
PARSER.declareInt(constructorArg(), CHECKSUM);
PARSER.declareObject(constructorArg(), (p, context) -> p.map(), RESULTS);
}
private final int checksum;

View File

@ -62,7 +62,7 @@ public abstract class AbstractNativeProcess implements NativeProcess {
Consumer<String> onProcessCrash) {
this.jobId = jobId;
cppLogHandler = new CppLogMessageHandler(jobId, logStream);
this.processInStream = new BufferedOutputStream(processInStream);
this.processInStream = processInStream != null ? new BufferedOutputStream(processInStream) : null;
this.processOutStream = processOutStream;
this.processRestoreStream = processRestoreStream;
this.recordWriter = new LengthEncodedWriter(this.processInStream);
@ -87,19 +87,32 @@ public abstract class AbstractNativeProcess implements NativeProcess {
LOGGER.error(new ParameterizedMessage("[{}] Error tailing {} process logs", jobId, getName()), e);
}
} finally {
if (processCloseInitiated == false && processKilled == false) {
// The log message doesn't say "crashed", as the process could have been killed
// by a user or other process (e.g. the Linux OOM killer)
String errors = cppLogHandler.getErrors();
String fullError = String.format(Locale.ROOT, "[%s] %s process stopped unexpectedly: %s", jobId, getName(), errors);
LOGGER.error(fullError);
onProcessCrash.accept(fullError);
}
detectCrash();
}
});
}
/**
* Tries to detect whether the process crashed, i.e. stopped prematurely without any known reason.
*/
private void detectCrash() {
if (processCloseInitiated || processKilled) {
// Do not detect crash when the process is being closed or killed.
return;
}
if (processInStream == null) {
// Do not detect crash when the process has been closed automatically.
// This is possible when the process does not have an input pipe to hang on to and closes right after writing its output.
return;
}
// The log message doesn't say "crashed", as the process could have been killed
// by a user or other process (e.g. the Linux OOM killer)
String errors = cppLogHandler.getErrors();
String fullError = String.format(Locale.ROOT, "[%s] %s process stopped unexpectedly: %s", jobId, getName(), errors);
LOGGER.error(fullError);
onProcessCrash.accept(fullError);
}
/**
* Starts a process that may persist its state
* @param executorService the executor service to run on
@ -147,7 +160,9 @@ public abstract class AbstractNativeProcess implements NativeProcess {
try {
processCloseInitiated = true;
// closing its input causes the process to exit
processInStream.close();
if (processInStream != null) {
processInStream.close();
}
// wait for the process to exit by waiting for end-of-file on the named pipe connected
// to the state processor - it may take a long time for all the model state to be
// indexed
@ -192,7 +207,9 @@ public abstract class AbstractNativeProcess implements NativeProcess {
LOGGER.warn("[{}] Failed to get PID of {} process to kill", jobId, getName());
} finally {
try {
processInStream.close();
if (processInStream != null) {
processInStream.close();
}
} catch (IOException e) {
// Ignore it - we're shutting down and the method itself has logged a warning
}
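The "closing its input causes the process to exit" contract relied on above is a general child-process pattern, not something Elasticsearch-specific. A standalone illustration (assumes a Unix-like `cat` on the PATH):

[source,java]
----
import java.io.IOException;

public final class StdinCloseDemo {
    public static void main(String[] args) throws IOException, InterruptedException {
        Process child = new ProcessBuilder("cat").start();   // echoes stdin until EOF
        child.getOutputStream().close();                     // EOF on the child's stdin...
        System.out.println("exit code: " + child.waitFor()); // ...so cat exits normally
    }
}
----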

View File

@ -0,0 +1,37 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.rest.dataframe;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction;
import org.elasticsearch.xpack.ml.MachineLearning;
import java.io.IOException;
public class RestEstimateMemoryUsageAction extends BaseRestHandler {
public RestEstimateMemoryUsageAction(RestController controller) {
controller.registerHandler(
RestRequest.Method.POST,
MachineLearning.BASE_PATH + "data_frame/analytics/_estimate_memory_usage", this);
}
@Override
public String getName() {
return "ml_estimate_memory_usage_action";
}
@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
EstimateMemoryUsageAction.Request request =
EstimateMemoryUsageAction.Request.parseRequest(restRequest.contentOrSourceParamParser());
return channel -> client.execute(EstimateMemoryUsageAction.INSTANCE, request, new RestToXContentListener<>(channel));
}
}
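For completeness, a hedged sketch of invoking the same action from Java rather than REST. Only `parseRequest` is visible in this handler, so the `Request` constructor taking a config, and the surrounding `client`/`config` wiring, are assumptions:

[source,java]
----
EstimateMemoryUsageAction.Request request = new EstimateMemoryUsageAction.Request(config);
client.execute(EstimateMemoryUsageAction.INSTANCE, request, ActionListener.wrap(
    response -> { /* read the two estimates off the response */ },
    e -> { /* handle the failure */ }));
----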

View File

@ -58,7 +58,8 @@ public class DataFrameAnalyticsIndexTests extends ESTestCase {
private static final String[] SOURCE_INDEX = new String[] {"source-index"};
private static final String DEST_INDEX = "dest-index";
private static final DataFrameAnalyticsConfig ANALYTICS_CONFIG =
new DataFrameAnalyticsConfig.Builder(ANALYTICS_ID)
new DataFrameAnalyticsConfig.Builder()
.setId(ANALYTICS_ID)
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null))
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, null))
.setAnalysis(new OutlierDetection())

View File

@ -63,7 +63,8 @@ public class SourceDestValidatorTests extends ESTestCase {
}
public void testCheck_GivenSimpleSourceIndexAndValidDestIndex() {
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder("test")
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId("test")
.setSource(createSource("source-1"))
.setDest(new DataFrameAnalyticsDest("dest", null))
.setAnalysis(new OutlierDetection())
@ -74,7 +75,8 @@ public class SourceDestValidatorTests extends ESTestCase {
}
public void testCheck_GivenMissingConcreteSourceIndex() {
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder("test")
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId("test")
.setSource(createSource("missing"))
.setDest(new DataFrameAnalyticsDest("dest", null))
.setAnalysis(new OutlierDetection())
@ -88,7 +90,8 @@ public class SourceDestValidatorTests extends ESTestCase {
}
public void testCheck_GivenMissingWildcardSourceIndex() {
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder("test")
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId("test")
.setSource(createSource("missing*"))
.setDest(new DataFrameAnalyticsDest("dest", null))
.setAnalysis(new OutlierDetection())
@ -102,7 +105,8 @@ public class SourceDestValidatorTests extends ESTestCase {
}
public void testCheck_GivenDestIndexSameAsSourceIndex() {
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder("test")
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId("test")
.setSource(createSource("source-1"))
.setDest(new DataFrameAnalyticsDest("source-1", null))
.setAnalysis(new OutlierDetection())
@ -116,7 +120,8 @@ public class SourceDestValidatorTests extends ESTestCase {
}
public void testCheck_GivenDestIndexMatchesSourceIndex() {
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder("test")
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId("test")
.setSource(createSource("source-*"))
.setDest(new DataFrameAnalyticsDest(SOURCE_2, null))
.setAnalysis(new OutlierDetection())
@ -130,7 +135,8 @@ public class SourceDestValidatorTests extends ESTestCase {
}
public void testCheck_GivenDestIndexMatchesOneOfSourceIndices() {
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder("test")
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId("test")
.setSource(createSource("source-1,source-*"))
.setDest(new DataFrameAnalyticsDest(SOURCE_2, null))
.setAnalysis(new OutlierDetection())
@ -144,7 +150,8 @@ public class SourceDestValidatorTests extends ESTestCase {
}
public void testCheck_GivenDestIndexIsAliasThatMatchesMultipleIndices() {
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder("test")
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId("test")
.setSource(createSource(SOURCE_1))
.setDest(new DataFrameAnalyticsDest("dest-alias", null))
.setAnalysis(new OutlierDetection())
@ -159,7 +166,8 @@ public class SourceDestValidatorTests extends ESTestCase {
}
public void testCheck_GivenDestIndexIsAliasThatIsIncludedInSource() {
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder("test")
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId("test")
.setSource(createSource("source-1"))
.setDest(new DataFrameAnalyticsDest("source-1-alias", null))
.setAnalysis(new OutlierDetection())

View File

@ -43,7 +43,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.addAggregatableField("some_float", "float").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), RESULTS_FIELD, false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields();
@ -58,7 +58,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), RESULTS_FIELD, false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields();
@ -72,7 +72,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.addAggregatableField("some_keyword", "keyword").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), RESULTS_FIELD, false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
@ -83,7 +83,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.addAggregatableField("indecisive_field", "float", "keyword").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), RESULTS_FIELD, false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
@ -97,7 +97,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), RESULTS_FIELD, false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields();
@ -117,7 +117,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities);
SOURCE_INDEX, buildRegressionConfig("foo"), RESULTS_FIELD, false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<ExtractedField> allFields = extractedFields.getAllFields();
@ -136,7 +136,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities);
SOURCE_INDEX, buildRegressionConfig("foo"), RESULTS_FIELD, false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("required fields [foo] are missing"));
@ -147,7 +147,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.addAggregatableField("_id", "float").build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), RESULTS_FIELD, false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
@ -169,7 +169,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FieldCapabilitiesResponse fieldCapabilities = mockFieldCapsResponseBuilder.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), RESULTS_FIELD, false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -186,7 +186,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your_field1", "my*"}, new String[0]);
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), RESULTS_FIELD, false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No field [your_field1] could be detected"));
@ -201,7 +201,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FetchSourceContext desiredFields = new FetchSourceContext(true, new String[0], new String[]{"my_*"});
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), RESULTS_FIELD, false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]"));
}
@ -217,7 +217,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"});
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), RESULTS_FIELD, false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -234,7 +234,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), RESULTS_FIELD, false, 100, fieldCapabilities);
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
assertThat(e.getMessage(), equalTo("A field that matches the dest.results_field [ml] already exists; " +
@ -250,7 +250,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(), true, 100, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), RESULTS_FIELD, true, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -258,6 +258,23 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
assertThat(extractedFieldNames, equalTo(Arrays.asList("my_field1", "your_field2")));
}
public void testDetectedExtractedFields_NullResultsField() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField(RESULTS_FIELD, "float")
.addAggregatableField("my_field1", "float")
.addAggregatableField("your_field2", "float")
.addAggregatableField("your_keyword", "keyword")
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(), null, false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, equalTo(Arrays.asList(RESULTS_FIELD, "my_field1", "your_field2")));
}
public void testDetectedExtractedFields_GivenLessFieldsThanDocValuesLimit() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("field_1", "float")
@ -267,7 +284,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(), true, 4, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), RESULTS_FIELD, true, 4, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -286,7 +303,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(), true, 3, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), RESULTS_FIELD, true, 3, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -305,7 +322,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
.build();
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildOutlierDetectionConfig(), true, 2, fieldCapabilities);
SOURCE_INDEX, buildOutlierDetectionConfig(), RESULTS_FIELD, true, 2, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
@ -320,9 +337,10 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
}
private static DataFrameAnalyticsConfig buildOutlierDetectionConfig(FetchSourceContext analyzedFields) {
return new DataFrameAnalyticsConfig.Builder("foo")
return new DataFrameAnalyticsConfig.Builder()
.setId("foo")
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null))
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, null))
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD))
.setAnalyzedFields(analyzedFields)
.setAnalysis(new OutlierDetection())
.build();
@ -333,9 +351,10 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
}
private static DataFrameAnalyticsConfig buildRegressionConfig(String dependentVariable, FetchSourceContext analyzedFields) {
return new DataFrameAnalyticsConfig.Builder("foo")
return new DataFrameAnalyticsConfig.Builder()
.setId("foo")
.setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null))
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, null))
.setDest(new DataFrameAnalyticsDest(DEST_INDEX, RESULTS_FIELD))
.setAnalyzedFields(analyzedFields)
.setAnalysis(new Regression(dependentVariable))
.build();
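These test updates all track the same API change: `DataFrameAnalyticsConfig.Builder` no longer takes the id in its constructor, so every field is set through setters. That is what lets the `_estimate_memory_usage` request carry a config with `id` and `dest` omitted. A hedged sketch of such a partial config (whether `build()` enforces the missing fields is not shown here, so treat this as illustrative only):

[source,java]
----
DataFrameAnalyticsConfig.Builder partial = new DataFrameAnalyticsConfig.Builder()
    .setSource(new DataFrameAnalyticsSource(new String[] {"some-source-index"}, null))
    .setAnalysis(new OutlierDetection());
// no setId(...), no setDest(...): exactly the shape the estimate API accepts
----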

View File

@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.not;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
public class AnalyticsBuilderTests extends ESTestCase {
private NativeController nativeController;
private ProcessPipes processPipes;
private AnalyticsProcessConfig config;
private List<Path> filesToDelete;
private ArgumentCaptor<List<String>> commandCaptor;
private AnalyticsBuilder analyticsBuilder;
@SuppressWarnings("unchecked")
@Before
public void setUpMocks() {
nativeController = mock(NativeController.class);
processPipes = mock(ProcessPipes.class);
config = mock(AnalyticsProcessConfig.class);
filesToDelete = new ArrayList<>();
commandCaptor = ArgumentCaptor.forClass((Class) List.class);
analyticsBuilder = new AnalyticsBuilder(LuceneTestCase::createTempDir, nativeController, processPipes, config, filesToDelete);
}
public void testBuild_Analytics() throws Exception {
analyticsBuilder.build();
assertThat(filesToDelete, hasSize(1));
verify(nativeController).startProcess(commandCaptor.capture());
verifyNoMoreInteractions(nativeController);
List<String> command = commandCaptor.getValue();
assertThat(command, not(hasItem("--memoryUsageEstimationOnly")));
}
public void testBuild_MemoryUsageEstimation() throws Exception {
analyticsBuilder
.performMemoryUsageEstimationOnly()
.build();
assertThat(filesToDelete, hasSize(1));
verify(nativeController).startProcess(commandCaptor.capture());
verifyNoMoreInteractions(nativeController);
List<String> command = commandCaptor.getValue();
assertThat(command, hasItem("--memoryUsageEstimationOnly"));
}
}
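The flag under test is only asserted on here, not defined; a hedged sketch of the `AnalyticsBuilder` side it implies (the field name and the command-assembly context are assumptions, only `performMemoryUsageEstimationOnly()` and the `--memoryUsageEstimationOnly` literal appear in this commit):

[source,java]
----
private boolean performMemoryUsageEstimationOnly;

AnalyticsBuilder performMemoryUsageEstimationOnly() {
    this.performMemoryUsageEstimationOnly = true;
    return this;
}

// ...inside the method that assembles the native process command line:
if (performMemoryUsageEstimationOnly) {
    command.add("--memoryUsageEstimationOnly");
}
----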

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.junit.Before;
import org.mockito.InOrder;
@ -25,12 +26,13 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
private static final String JOB_ID = "analytics-result-processor-tests";
private AnalyticsProcess process;
private AnalyticsProcess<AnalyticsResult> process;
private DataFrameRowsJoiner dataFrameRowsJoiner;
private int progressPercent;
@Before
@SuppressWarnings("unchecked")
public void setUpMocks() {
process = mock(AnalyticsProcess.class);
dataFrameRowsJoiner = mock(DataFrameRowsJoiner.class);

View File

@ -0,0 +1,183 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyBoolean;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
public class MemoryUsageEstimationProcessManagerTests extends ESTestCase {
private static final String TASK_ID = "mem_est_123";
private static final String CONFIG_ID = "dummy";
private static final int NUM_ROWS = 100;
private static final int NUM_COLS = 4;
private static final MemoryUsageEstimationResult PROCESS_RESULT_ZERO =
new MemoryUsageEstimationResult(ByteSizeValue.ZERO, ByteSizeValue.ZERO);
private static final MemoryUsageEstimationResult PROCESS_RESULT =
new MemoryUsageEstimationResult(ByteSizeValue.parseBytesSizeValue("20kB", ""), ByteSizeValue.parseBytesSizeValue("10kB", ""));
private ExecutorService executorServiceForJob;
private ExecutorService executorServiceForProcess;
private AnalyticsProcess<MemoryUsageEstimationResult> process;
private AnalyticsProcessFactory<MemoryUsageEstimationResult> processFactory;
private DataFrameDataExtractor dataExtractor;
private DataFrameDataExtractorFactory dataExtractorFactory;
private DataFrameAnalyticsConfig dataFrameAnalyticsConfig;
private ActionListener<MemoryUsageEstimationResult> listener;
private ArgumentCaptor<MemoryUsageEstimationResult> resultCaptor;
private ArgumentCaptor<Exception> exceptionCaptor;
private MemoryUsageEstimationProcessManager processManager;
@SuppressWarnings("unchecked")
@Before
public void setUpMocks() {
executorServiceForJob = EsExecutors.newDirectExecutorService();
executorServiceForProcess = mock(ExecutorService.class);
process = mock(AnalyticsProcess.class);
when(process.isProcessAlive()).thenReturn(true);
when(process.readAnalyticsResults()).thenReturn(Arrays.asList(PROCESS_RESULT).iterator());
processFactory = mock(AnalyticsProcessFactory.class);
when(processFactory.createAnalyticsProcess(anyString(), any(), any(), any())).thenReturn(process);
dataExtractor = mock(DataFrameDataExtractor.class);
when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);
when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);
dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandom(CONFIG_ID);
listener = mock(ActionListener.class);
resultCaptor = ArgumentCaptor.forClass(MemoryUsageEstimationResult.class);
exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
processManager = new MemoryUsageEstimationProcessManager(executorServiceForJob, executorServiceForProcess, processFactory);
}
public void testRunJob_EmptyDataFrame() {
when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(0, NUM_COLS));
processManager.runJobAsync(TASK_ID, dataFrameAnalyticsConfig, dataExtractorFactory, listener);
verify(listener).onResponse(resultCaptor.capture());
MemoryUsageEstimationResult result = resultCaptor.getValue();
assertThat(result, equalTo(PROCESS_RESULT_ZERO));
verifyNoMoreInteractions(process, listener);
}
public void testRunJob_ProcessNotAlive() {
when(process.isProcessAlive()).thenReturn(false);
processManager.runJobAsync(TASK_ID, dataFrameAnalyticsConfig, dataExtractorFactory, listener);
verify(listener).onFailure(exceptionCaptor.capture());
ElasticsearchException exception = (ElasticsearchException) exceptionCaptor.getValue();
assertThat(exception.status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
assertThat(exception.getMessage(), containsString(TASK_ID));
assertThat(exception.getMessage(), containsString("Error while starting process"));
verify(process).isProcessAlive();
verifyNoMoreInteractions(process, listener);
}
public void testRunJob_NoResults() throws Exception {
when(process.readAnalyticsResults()).thenReturn(Arrays.<MemoryUsageEstimationResult>asList().iterator());
processManager.runJobAsync(TASK_ID, dataFrameAnalyticsConfig, dataExtractorFactory, listener);
verify(listener).onFailure(exceptionCaptor.capture());
ElasticsearchException exception = (ElasticsearchException) exceptionCaptor.getValue();
assertThat(exception.status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
assertThat(exception.getMessage(), containsString(TASK_ID));
assertThat(exception.getMessage(), containsString("no results"));
InOrder inOrder = inOrder(process);
inOrder.verify(process).isProcessAlive();
inOrder.verify(process).readAnalyticsResults();
inOrder.verify(process).consumeAndCloseOutputStream();
inOrder.verify(process).close();
verifyNoMoreInteractions(process, listener);
}
public void testRunJob_MultipleResults() throws Exception {
when(process.readAnalyticsResults()).thenReturn(Arrays.asList(PROCESS_RESULT, PROCESS_RESULT).iterator());
processManager.runJobAsync(TASK_ID, dataFrameAnalyticsConfig, dataExtractorFactory, listener);
verify(listener).onFailure(exceptionCaptor.capture());
ElasticsearchException exception = (ElasticsearchException) exceptionCaptor.getValue();
assertThat(exception.status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
assertThat(exception.getMessage(), containsString(TASK_ID));
assertThat(exception.getMessage(), containsString("more than one result"));
InOrder inOrder = inOrder(process);
inOrder.verify(process).isProcessAlive();
inOrder.verify(process).readAnalyticsResults();
inOrder.verify(process).consumeAndCloseOutputStream();
inOrder.verify(process).close();
verifyNoMoreInteractions(process, listener);
}
public void testRunJob_FailsOnClose() throws Exception {
doThrow(ExceptionsHelper.serverError("some LOG(ERROR) lines coming from cpp process")).when(process).close();
processManager.runJobAsync(TASK_ID, dataFrameAnalyticsConfig, dataExtractorFactory, listener);
verify(listener).onFailure(exceptionCaptor.capture());
ElasticsearchException exception = (ElasticsearchException) exceptionCaptor.getValue();
assertThat(exception.status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
assertThat(exception.getMessage(), containsString(TASK_ID));
assertThat(exception.getMessage(), containsString("Error while closing process"));
InOrder inOrder = inOrder(process);
inOrder.verify(process).isProcessAlive();
inOrder.verify(process).readAnalyticsResults();
inOrder.verify(process).consumeAndCloseOutputStream();
inOrder.verify(process).close();
verifyNoMoreInteractions(process, listener);
}
public void testRunJob_Ok() throws Exception {
processManager.runJobAsync(TASK_ID, dataFrameAnalyticsConfig, dataExtractorFactory, listener);
verify(listener).onResponse(resultCaptor.capture());
MemoryUsageEstimationResult result = resultCaptor.getValue();
assertThat(result, equalTo(PROCESS_RESULT));
InOrder inOrder = inOrder(process);
inOrder.verify(process).isProcessAlive();
inOrder.verify(process).readAnalyticsResults();
inOrder.verify(process).consumeAndCloseOutputStream();
inOrder.verify(process).close();
verifyNoMoreInteractions(process, listener);
}
}

View File

@ -3,12 +3,10 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process;
package org.elasticsearch.xpack.ml.dataframe.process.results;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResultsTests;
import java.io.IOException;

View File

@ -0,0 +1,51 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.results;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import java.io.IOException;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
public class MemoryUsageEstimationResultTests extends AbstractXContentTestCase<MemoryUsageEstimationResult> {
public static MemoryUsageEstimationResult createRandomResult() {
return new MemoryUsageEstimationResult(
randomBoolean() ? new ByteSizeValue(randomNonNegativeLong()) : null,
randomBoolean() ? new ByteSizeValue(randomNonNegativeLong()) : null);
}
@Override
protected MemoryUsageEstimationResult createTestInstance() {
return createRandomResult();
}
@Override
protected MemoryUsageEstimationResult doParseInstance(XContentParser parser) throws IOException {
return MemoryUsageEstimationResult.PARSER.apply(parser, null);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
public void testConstructor_NullValues() {
MemoryUsageEstimationResult result = new MemoryUsageEstimationResult(null, null);
assertThat(result.getExpectedMemoryUsageWithOnePartition(), nullValue());
assertThat(result.getExpectedMemoryUsageWithMaxPartitions(), nullValue());
}
public void testConstructor() {
MemoryUsageEstimationResult result = new MemoryUsageEstimationResult(new ByteSizeValue(2048), new ByteSizeValue(1024));
assertThat(result.getExpectedMemoryUsageWithOnePartition(), equalTo(new ByteSizeValue(2048)));
assertThat(result.getExpectedMemoryUsageWithMaxPartitions(), equalTo(new ByteSizeValue(1024)));
}
}

View File

@ -0,0 +1,155 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.process;

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import static org.hamcrest.Matchers.is;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

public class AbstractNativeProcessTests extends ESTestCase {

    private InputStream logStream;
    private OutputStream inputStream;
    private InputStream outputStream;
    private OutputStream restoreStream;
    private Consumer<String> onProcessCrash;
    private ExecutorService executorService;
    private CountDownLatch wait = new CountDownLatch(1);

    @Before
    @SuppressWarnings("unchecked")
    public void initialize() throws IOException {
        logStream = mock(InputStream.class);
        // This answer blocks the thread on the executor service.
        // In order to unblock it, the test needs to call wait.countDown().
        when(logStream.read(new byte[1024])).thenAnswer(
            invocationOnMock -> {
                wait.await();
                return -1;
            });
        inputStream = mock(OutputStream.class);
        outputStream = mock(InputStream.class);
        when(outputStream.read(new byte[512])).thenReturn(-1);
        restoreStream = mock(OutputStream.class);
        onProcessCrash = mock(Consumer.class);
        executorService = EsExecutors.newFixed("test", 1, 1, EsExecutors.daemonThreadFactory("test"), new ThreadContext(Settings.EMPTY));
    }

    @After
    public void terminateExecutorService() {
        ThreadPool.terminate(executorService, 10, TimeUnit.SECONDS);
        verifyNoMoreInteractions(onProcessCrash);
    }

    public void testStart_DoNotDetectCrashWhenNoInputPipeProvided() throws Exception {
        try (AbstractNativeProcess process = new TestNativeProcess(null)) {
            process.start(executorService);
            wait.countDown();
        }
    }

    public void testStart_DoNotDetectCrashWhenProcessIsBeingClosed() throws Exception {
        try (AbstractNativeProcess process = new TestNativeProcess(inputStream)) {
            process.start(executorService);
            process.close();
            wait.countDown();
        }
    }

    public void testStart_DoNotDetectCrashWhenProcessIsBeingKilled() throws Exception {
        try (AbstractNativeProcess process = new TestNativeProcess(inputStream)) {
            process.start(executorService);
            process.kill();
            wait.countDown();
        }
    }

    public void testStart_DetectCrashWhenInputPipeExists() throws Exception {
        try (AbstractNativeProcess process = new TestNativeProcess(inputStream)) {
            process.start(executorService);
            wait.countDown();
            ThreadPool.terminate(executorService, 10, TimeUnit.SECONDS);

            verify(onProcessCrash).accept("[foo] test process stopped unexpectedly: ");
        }
    }

    public void testWriteRecord() throws Exception {
        try (AbstractNativeProcess process = new TestNativeProcess(inputStream)) {
            process.writeRecord(new String[] {"a", "b", "c"});
            process.flushStream();

            verify(inputStream).write(any(), anyInt(), anyInt());
        }
    }

    public void testWriteRecord_FailWhenNoInputPipeProvided() throws Exception {
        try (AbstractNativeProcess process = new TestNativeProcess(null)) {
            expectThrows(NullPointerException.class, () -> process.writeRecord(new String[] {"a", "b", "c"}));
        }
    }

    public void testFlush() throws Exception {
        try (AbstractNativeProcess process = new TestNativeProcess(inputStream)) {
            process.flushStream();

            verify(inputStream).flush();
        }
    }

    public void testFlush_FailWhenNoInputPipeProvided() throws Exception {
        try (AbstractNativeProcess process = new TestNativeProcess(null)) {
            expectThrows(NullPointerException.class, () -> process.flushStream());
        }
    }

    public void testIsReady() throws Exception {
        try (AbstractNativeProcess process = new TestNativeProcess(null)) {
            assertThat(process.isReady(), is(false));
            process.setReady();
            assertThat(process.isReady(), is(true));
        }
    }

    /**
     * Mock-based implementation of {@link AbstractNativeProcess}.
     */
    private class TestNativeProcess extends AbstractNativeProcess {

        TestNativeProcess(OutputStream inputStream) {
            super("foo", logStream, inputStream, outputStream, restoreStream, 0, null, onProcessCrash);
        }

        @Override
        public String getName() {
            return "test";
        }

        @Override
        public void persistState() throws IOException {
        }
    }
}
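
These tests pin down the crash-detection contract without showing the production code. The following is a hypothetical sketch, assuming `start()` submits a log-tailing task to the executor and that hitting EOF on the log stream outside an orderly `close()` or `kill()` counts as a crash; the class, flags, and message below merely mirror the tests' expectations and are not the actual `AbstractNativeProcess` implementation:

[source,java]
----
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

class CrashDetectionSketch {

    private final AtomicBoolean processCloseInitiated = new AtomicBoolean();
    private final AtomicBoolean processKilled = new AtomicBoolean();

    void start(ExecutorService executorService, InputStream logStream, Consumer<String> onProcessCrash) {
        executorService.execute(() -> {
            byte[] buffer = new byte[1024];
            try {
                // Block until the log stream is exhausted; the mocked stream in the
                // tests above parks here until wait.countDown() is called.
                while (logStream.read(buffer) != -1) {
                    // discard log output in this sketch
                }
            } catch (IOException e) {
                // a broken stream is treated like EOF for crash-detection purposes
            }
            if (processCloseInitiated.get() == false && processKilled.get() == false) {
                // message shape matches the assertion in testStart_DetectCrashWhenInputPipeExists
                onProcessCrash.accept("[foo] test process stopped unexpectedly: ");
            }
        });
    }

    void close() {
        processCloseInitiated.set(true);
    }

    void kill() {
        processKilled.set(true);
    }
}
----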


@@ -0,0 +1,16 @@
{
  "ml.estimate_memory_usage": {
    "documentation": "http://www.elastic.co/guide/en/elasticsearch/reference/current/estimate-memory-usage-dfanalytics.html",
    "stability": "experimental",
    "methods": [ "POST" ],
    "url": {
      "path": "/_ml/data_frame/analytics/_estimate_memory_usage",
      "paths": [ "/_ml/data_frame/analytics/_estimate_memory_usage" ],
      "parts": {}
    },
    "body": {
      "description" : "Memory usage estimation definition",
      "required" : true
    }
  }
}
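
Given only the method and path registered in the spec above, the endpoint can be exercised from Java with the low-level REST client; the host, port, and `logdata` source index below are illustrative assumptions:

[source,java]
----
import org.apache.http.HttpHost;
import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

public class EstimateMemoryUsageExample {
    public static void main(String[] args) throws Exception {
        try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
            Request request = new Request("POST", "/_ml/data_frame/analytics/_estimate_memory_usage");
            // The body is required by the spec; "logdata" is an illustrative source index.
            request.setJsonEntity(
                "{\"data_frame_analytics_config\": {"
                    + "\"source\": {\"index\": \"logdata\"},"
                    + "\"analysis\": {\"outlier_detection\": {}}"
                    + "}}");
            Response response = client.performRequest(request);
            // Prints a JSON object holding the two expected_memory_usage_* fields.
            System.out.println(EntityUtils.toString(response.getEntity()));
        }
    }
}
----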


@@ -0,0 +1,75 @@
---
setup:
  - do:
      indices.create:
        index: index-source
        body:
          mappings:
            properties:
              x:
                type: float
              y:
                type: float

---
"Test memory usage estimation for empty data frame":
  - do:
      ml.estimate_memory_usage:
        body:
          data_frame_analytics_config:
            source: { index: "index-source" }
            analysis: { outlier_detection: {} }
  - match: { expected_memory_usage_with_one_partition: "0" }
  - match: { expected_memory_usage_with_max_partitions: "0" }

---
"Test memory usage estimation for non-empty data frame":
  - do:
      index:
        index: index-source
        refresh: true
        body: { x: 1, y: 10 }
  - match: { result: "created" }

  - do:
      ml.estimate_memory_usage:
        body:
          data_frame_analytics_config:
            source: { index: "index-source" }
            analysis: { outlier_detection: {} }
  - match: { expected_memory_usage_with_one_partition: "3kb" }
  - match: { expected_memory_usage_with_max_partitions: "3kb" }

  - do:
      index:
        index: index-source
        refresh: true
        body: { x: 2, y: 20 }
  - match: { result: "created" }

  - do:
      ml.estimate_memory_usage:
        body:
          data_frame_analytics_config:
            source: { index: "index-source" }
            analysis: { outlier_detection: {} }
  - match: { expected_memory_usage_with_one_partition: "4kb" }
  - match: { expected_memory_usage_with_max_partitions: "4kb" }

  - do:
      index:
        index: index-source
        refresh: true
        body: { x: 3, y: 30 }
  - match: { result: "created" }

  - do:
      ml.estimate_memory_usage:
        body:
          data_frame_analytics_config:
            source: { index: "index-source" }
            analysis: { outlier_detection: {} }
  - match: { expected_memory_usage_with_one_partition: "6kb" }
  - match: { expected_memory_usage_with_max_partitions: "5kb" }