diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java index e78c6015ec1..8688bc32ee0 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.ml.integration; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.admin.indices.settings.get.GetSettingsRequest; import org.elasticsearch.action.admin.indices.settings.get.GetSettingsResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; @@ -13,12 +14,18 @@ import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.junit.After; import java.util.Arrays; @@ -33,6 +40,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static 
org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.hamcrest.Matchers.startsWith; public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTestCase { @@ -430,4 +438,45 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest } assertThat(resultsWithPrediction, greaterThan(0)); } + + public void testModelMemoryLimitLowerThanEstimatedMemoryUsage() { + String sourceIndex = "test-model-memory-limit"; + + client().admin().indices().prepareCreate(sourceIndex) + .addMapping("_doc", "col_1", "type=double", "col_2", "type=float", "col_3", "type=keyword") + .get(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + for (int i = 0; i < 10000; i++) { // This number of rows should make memory usage estimate greater than 1MB + IndexRequest indexRequest = new IndexRequest(sourceIndex) + .id("doc_" + i) + .source("col_1", 1.0, "col_2", 1.0, "col_3", "str"); + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + + String id = "test_model_memory_limit_lower_than_estimated_memory_usage"; + ByteSizeValue modelMemoryLimit = new ByteSizeValue(1, ByteSizeUnit.MB); + DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder() + .setId(id) + .setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null)) + .setDest(new DataFrameAnalyticsDest(sourceIndex + "-results", null)) + .setAnalysis(new OutlierDetection()) + .setModelMemoryLimit(modelMemoryLimit) + .build(); + + registerAnalytics(config); + putAnalytics(config); + assertState(id, DataFrameAnalyticsState.STOPPED); + + ElasticsearchStatusException exception = expectThrows(ElasticsearchStatusException.class, () -> startAnalytics(id)); + 
assertThat(exception.status(), equalTo(RestStatus.BAD_REQUEST)); + assertThat( + exception.getMessage(), + startsWith("Cannot start because the configured model memory limit [" + modelMemoryLimit + + "] is lower than the expected memory usage")); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index 8835de2e228..676e56a852a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -51,6 +51,7 @@ import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; @@ -168,11 +169,36 @@ public class TransportStartDataFrameAnalyticsAction ); // Tell the job tracker to refresh the memory requirement for this job and all other jobs that have persistent tasks + ActionListener<EstimateMemoryUsageAction.Response> estimateMemoryUsageListener = ActionListener.wrap( + estimateMemoryUsageResponse -> { + // Validate that model memory limit is sufficient to run the analysis + if (configHolder.get().getModelMemoryLimit() + .compareTo(estimateMemoryUsageResponse.getExpectedMemoryUsageWithOnePartition()) < 0) { + ElasticsearchStatusException e = + ExceptionsHelper.badRequestException( + "Cannot start because the configured model memory limit [{}] is lower than the expected memory usage [{}]", + 
configHolder.get().getModelMemoryLimit(), estimateMemoryUsageResponse.getExpectedMemoryUsageWithOnePartition()); + listener.onFailure(e); + return; + } + // Refresh memory requirement for jobs + memoryTracker.addDataFrameAnalyticsJobMemoryAndRefreshAllOthers( + request.getId(), configHolder.get().getModelMemoryLimit().getBytes(), memoryRequirementRefreshListener); + }, + listener::onFailure + ); + + // Perform memory usage estimation for this config ActionListener<DataFrameAnalyticsConfig> configListener = ActionListener.wrap( config -> { configHolder.set(config); - memoryTracker.addDataFrameAnalyticsJobMemoryAndRefreshAllOthers( - request.getId(), config.getModelMemoryLimit().getBytes(), memoryRequirementRefreshListener); + EstimateMemoryUsageAction.Request estimateMemoryUsageRequest = new EstimateMemoryUsageAction.Request(config); + ClientHelper.executeAsyncWithOrigin( + client, + ClientHelper.ML_ORIGIN, + EstimateMemoryUsageAction.INSTANCE, + estimateMemoryUsageRequest, + estimateMemoryUsageListener); }, listener::onFailure ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java index 17595db791e..fac084c0fc8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java @@ -9,6 +9,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import 
org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -66,7 +67,9 @@ public class MemoryUsageEstimationProcessManager { new AnalyticsProcessConfig( dataSummary.rows, dataSummary.cols, - DataFrameAnalyticsConfig.MIN_MODEL_MEMORY_LIMIT, + // For memory estimation the model memory limit here should be set high enough not to trigger an error when C++ code + // compares the limit to the result of estimation. + new ByteSizeValue(1, ByteSizeUnit.PB), 1, "", categoricalFields,