diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java index 7f640e5eeab..ded1e0491de 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.ml.integration; import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; @@ -13,6 +14,7 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; @@ -22,6 +24,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.lessThanOrEqualTo; @@ -127,6 +131,49 @@ public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsInteg lessThanOrEqualTo(allDataUsedForTraining)); } + public void testSimultaneousExplainSameConfig() throws IOException { + + final int simultaneousInvocationCount = 10; + + String sourceIndex = "test-simultaneous-explain"; + RegressionIT.indexData(sourceIndex, 100, 0); + + DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder() + .setId("dfa-simultaneous-explain-" + sourceIndex) + .setSource(new DataFrameAnalyticsSource(new String[]{sourceIndex}, + QueryProvider.fromParsedQuery(QueryBuilders.matchAllQuery()), + null)) + .setAnalysis(new Regression(RegressionIT.DEPENDENT_VARIABLE_FIELD, + BoostedTreeParams.builder().build(), + null, + 100.0, + null, + null, + null)) + .buildForExplain(); + + List> futures = new ArrayList<>(); + + for (int i = 0; i < simultaneousInvocationCount; ++i) { + futures.add(client().execute(ExplainDataFrameAnalyticsAction.INSTANCE, new PutDataFrameAnalyticsAction.Request(config))); + } + + ExplainDataFrameAnalyticsAction.Response previous = null; + for (ActionFuture future : futures) { + // The main purpose of this test is that actionGet() here will throw an exception + // if any of the simultaneous calls returns an error due to interaction between + // the many estimation processes that get run + ExplainDataFrameAnalyticsAction.Response current = future.actionGet(10000); + if (previous != null) { + // A secondary check the test can perform is that the multiple invocations + // return the same result (but it was failures due to unwanted interactions + // that caused this test to be written) + assertEquals(previous, current); + } + previous = current; + } + } + @Override boolean supportsInference() { return false; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcessFactory.java index 86c314ef5de..839268c3241 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcessFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcessFactory.java @@ -29,6 +29,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; public class NativeMemoryUsageEstimationProcessFactory implements AnalyticsProcessFactory { @@ -39,11 +40,13 @@ public class NativeMemoryUsageEstimationProcessFactory implements AnalyticsProce private final Environment env; private final NativeController nativeController; + private final AtomicLong counter; private volatile Duration processConnectTimeout; public NativeMemoryUsageEstimationProcessFactory(Environment env, NativeController nativeController, ClusterService clusterService) { this.env = Objects.requireNonNull(env); this.nativeController = Objects.requireNonNull(nativeController); + this.counter = new AtomicLong(0); setProcessConnectTimeout(MachineLearning.PROCESS_CONNECT_TIMEOUT.get(env.settings())); clusterService.getClusterSettings().addSettingsUpdateConsumer( MachineLearning.PROCESS_CONNECT_TIMEOUT, this::setProcessConnectTimeout); @@ -61,8 +64,13 @@ public class NativeMemoryUsageEstimationProcessFactory implements AnalyticsProce ExecutorService executorService, Consumer onProcessCrash) { List filesToDelete = new ArrayList<>(); + // The config ID passed to the process pipes is only used to make the file names unique. Since memory estimation can be + // called many times in quick succession for the same config the config ID alone is not sufficient to guarantee that the + // memory estimation process pipe names are unique. Therefore an increasing counter value is appended to the config ID + // to ensure uniqueness between calls. ProcessPipes processPipes = new ProcessPipes( - env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, config.getId(), false, false, true, false, false); + env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, config.getId() + "_" + counter.incrementAndGet(), + false, false, true, false, false); createNativeProcess(config.getId(), analyticsProcessConfig, filesToDelete, processPipes);