From 168b566844b1c83d6a895b48d6edab783b02d57d Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 19 Jun 2017 18:39:45 +0100 Subject: [PATCH] [ML] Allow closing a job that is restoring state (elastic/x-pack-elasticsearch#1746) This change enables closing a job while it is in the middle of restoring its state. This is has the benefit of allowing users to close jobs that due to relocation are `opened` but they are still restoring state. It also helps avoiding race conditions in tests. Part of this change also includes restoring the state as a separate step from the process creation. This means we no longer block the job map while the process is restoring its state. relates elastic/x-pack-elasticsearch#1270 Original commit: elastic/x-pack-elasticsearch@1713a4a7c4a6c458fad65d497483babc54db7218 --- .../xpack/ml/action/CloseJobAction.java | 5 +- .../xpack/ml/action/OpenJobAction.java | 24 ++-- .../xpack/ml/job/persistence/JobProvider.java | 70 ---------- .../ml/job/persistence/StateStreamer.java | 123 ++++++++++++++++++ .../autodetect/AutodetectCommunicator.java | 25 +++- .../process/autodetect/AutodetectProcess.java | 16 +++ .../autodetect/AutodetectProcessManager.java | 7 +- .../BlackHoleAutodetectProcess.java | 11 ++ .../autodetect/NativeAutodetectProcess.java | 32 ++++- .../NativeAutodetectProcessFactory.java | 14 +- .../ml/action/CloseJobActionRequestTests.java | 5 +- .../ml/job/persistence/JobProviderTests.java | 41 ------ .../job/persistence/StateStreamerTests.java | 86 ++++++++++++ .../AutodetectCommunicatorTests.java | 32 ++++- .../NativeAutodetectProcessTests.java | 10 +- 15 files changed, 339 insertions(+), 162 deletions(-) create mode 100644 plugin/src/main/java/org/elasticsearch/xpack/ml/job/persistence/StateStreamer.java create mode 100644 plugin/src/test/java/org/elasticsearch/xpack/ml/job/persistence/StateStreamerTests.java diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/action/CloseJobAction.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/action/CloseJobAction.java index 8388976c5c7..94d83f01e53 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/action/CloseJobAction.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/action/CloseJobAction.java @@ -616,6 +616,7 @@ public class CloseJobAction extends Action datafeed = mlMetadata.getDatafeedByJobId(jobId); if (datafeed.isPresent()) { DatafeedState datafeedState = MlMetadata.getDatafeedState(datafeed.get().getId(), tasks); diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/action/OpenJobAction.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/action/OpenJobAction.java index 5a30ebf780b..cf149438057 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/action/OpenJobAction.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/action/OpenJobAction.java @@ -55,6 +55,7 @@ import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.MlMetaIndex; import org.elasticsearch.xpack.ml.MlMetadata; import org.elasticsearch.xpack.ml.job.config.Job; +import org.elasticsearch.xpack.ml.job.config.JobState; import org.elasticsearch.xpack.ml.job.config.JobTaskStatus; import org.elasticsearch.xpack.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.ml.job.persistence.ElasticsearchMappings; @@ -521,22 +522,25 @@ public class OpenJobAction extends Action persistentTask) { - if (persistentTask == null) { - return false; + JobState jobState = JobState.CLOSED; + if (persistentTask != null) { + JobTaskStatus jobStateStatus = (JobTaskStatus) persistentTask.getStatus(); + jobState = jobStateStatus == null ? JobState.OPENING : jobStateStatus.getState(); } - JobTaskStatus jobState = (JobTaskStatus) persistentTask.getStatus(); - if (jobState == null) { - return false; - } - switch (jobState.getState()) { + switch (jobState) { + case OPENING: + case CLOSED: + return false; case OPENED: opened = true; return true; + case CLOSING: + throw ExceptionsHelper.conflictStatusException("The job has been " + JobState.CLOSED + " while waiting to be " + + JobState.OPENED); case FAILED: - return true; default: - throw new IllegalStateException("Unexpected job state [" + jobState.getState() + "]"); - + throw new IllegalStateException("Unexpected job state [" + jobState + + "] while waiting for job to be " + JobState.OPENED); } } } diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobProvider.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobProvider.java index 8cc2d27d049..2178e94243b 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobProvider.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobProvider.java @@ -6,8 +6,6 @@ package org.elasticsearch.xpack.ml.job.persistence; import org.apache.logging.log4j.Logger; -import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.BytesRefIterator; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.ElasticsearchStatusException; @@ -16,7 +14,6 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.indices.create.CreateIndexRequest; import org.elasticsearch.action.admin.indices.mapping.put.PutMappingResponse; -import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.search.MultiSearchRequestBuilder; import org.elasticsearch.action.search.MultiSearchResponse; import org.elasticsearch.action.search.SearchRequest; @@ -76,7 +73,6 @@ import org.elasticsearch.xpack.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.security.support.Exceptions; import java.io.IOException; -import java.io.OutputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -837,72 +833,6 @@ public class JobProvider { }, errorHandler)); } - /** - * Given a model snapshot, get the corresponding state and write it to the supplied - * stream. If there are multiple state documents they are separated using '\0' - * when written to the stream. - * - * Because we have a rule that we will not open a legacy job in the current product version - * we don't have to worry about legacy document IDs here. - * - * @param jobId the job id - * @param modelSnapshot the model snapshot to be restored - * @param restoreStream the stream to write the state to - */ - public void restoreStateToStream(String jobId, ModelSnapshot modelSnapshot, OutputStream restoreStream) throws IOException { - String indexName = AnomalyDetectorsIndex.jobStateIndexName(); - - // First try to restore model state. - for (String stateDocId : modelSnapshot.stateDocumentIds()) { - LOGGER.trace("ES API CALL: get ID {} from index {}", stateDocId, indexName); - - GetResponse stateResponse = client.prepareGet(indexName, ElasticsearchMappings.DOC_TYPE, stateDocId).get(); - if (!stateResponse.isExists()) { - LOGGER.error("Expected {} documents for model state for {} snapshot {} but failed to find {}", - modelSnapshot.getSnapshotDocCount(), jobId, modelSnapshot.getSnapshotId(), stateDocId); - break; - } - writeStateToStream(stateResponse.getSourceAsBytesRef(), restoreStream); - } - - // Secondly try to restore categorizer state. This must come after model state because that's - // the order the C++ process expects. There are no snapshots for this, so the IDs simply - // count up until a document is not found. It's NOT an error to have no categorizer state. - int docNum = 0; - while (true) { - String docId = CategorizerState.documentId(jobId, ++docNum); - - LOGGER.trace("ES API CALL: get ID {} from index {}", docId, indexName); - - GetResponse stateResponse = client.prepareGet(indexName, ElasticsearchMappings.DOC_TYPE, docId).get(); - if (!stateResponse.isExists()) { - break; - } - writeStateToStream(stateResponse.getSourceAsBytesRef(), restoreStream); - } - - } - - private void writeStateToStream(BytesReference source, OutputStream stream) throws IOException { - // The source bytes are already UTF-8. The C++ process wants UTF-8, so we - // can avoid converting to a Java String only to convert back again. - BytesRefIterator iterator = source.iterator(); - for (BytesRef ref = iterator.next(); ref != null; ref = iterator.next()) { - // There's a complication that the source can already have trailing 0 bytes - int length = ref.bytes.length; - while (length > 0 && ref.bytes[length - 1] == 0) { - --length; - } - if (length > 0) { - stream.write(ref.bytes, 0, length); - } - } - // This is dictated by RapidJSON on the C++ side; it treats a '\0' as end-of-file - // even when it's not really end-of-file, and this is what we need because we're - // sending multiple JSON documents via the same named pipe. - stream.write(0); - } - public QueryPage modelPlot(String jobId, int from, int size) { SearchResponse searchResponse; String indexName = AnomalyDetectorsIndex.jobResultsAliasedName(jobId); diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/persistence/StateStreamer.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/persistence/StateStreamer.java new file mode 100644 index 00000000000..81b818e0ca2 --- /dev/null +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/persistence/StateStreamer.java @@ -0,0 +1,123 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.job.persistence; + +import org.apache.logging.log4j.Logger; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefIterator; +import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.logging.Loggers; +import org.elasticsearch.xpack.ml.job.process.autodetect.state.CategorizerState; +import org.elasticsearch.xpack.ml.job.process.autodetect.state.ModelSnapshot; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.Objects; + +/** + * A {@code StateStreamer} fetches the various state documents and + * writes them into a stream. It allows cancellation via its + *{@link #cancel()} method; cancellation is checked between writing + * the various state documents. + */ +public class StateStreamer { + + private static final Logger LOGGER = Loggers.getLogger(StateStreamer.class); + + private final Client client; + private volatile boolean isCancelled; + + public StateStreamer(Client client) { + this.client = Objects.requireNonNull(client); + } + + /** + * Cancels the state streaming at the first opportunity. + */ + public void cancel() { + isCancelled = true; + } + + /** + * Given a model snapshot, get the corresponding state and write it to the supplied + * stream. If there are multiple state documents they are separated using '\0' + * when written to the stream. + * + * Because we have a rule that we will not open a legacy job in the current product version + * we don't have to worry about legacy document IDs here. + * + * @param jobId the job id + * @param modelSnapshot the model snapshot to be restored + * @param restoreStream the stream to write the state to + */ + public void restoreStateToStream(String jobId, ModelSnapshot modelSnapshot, OutputStream restoreStream) throws IOException { + String indexName = AnomalyDetectorsIndex.jobStateIndexName(); + + // First try to restore model state. + for (String stateDocId : modelSnapshot.stateDocumentIds()) { + if (isCancelled) { + return; + } + + LOGGER.trace("ES API CALL: get ID {} from index {}", stateDocId, indexName); + + GetResponse stateResponse = client.prepareGet(indexName, ElasticsearchMappings.DOC_TYPE, stateDocId).get(); + if (!stateResponse.isExists()) { + LOGGER.error("Expected {} documents for model state for {} snapshot {} but failed to find {}", + modelSnapshot.getSnapshotDocCount(), jobId, modelSnapshot.getSnapshotId(), stateDocId); + break; + } + writeStateToStream(stateResponse.getSourceAsBytesRef(), restoreStream); + } + + // Secondly try to restore categorizer state. This must come after model state because that's + // the order the C++ process expects. There are no snapshots for this, so the IDs simply + // count up until a document is not found. It's NOT an error to have no categorizer state. + int docNum = 0; + while (true) { + if (isCancelled) { + return; + } + + String docId = CategorizerState.documentId(jobId, ++docNum); + + LOGGER.trace("ES API CALL: get ID {} from index {}", docId, indexName); + + GetResponse stateResponse = client.prepareGet(indexName, ElasticsearchMappings.DOC_TYPE, docId).get(); + if (!stateResponse.isExists()) { + break; + } + writeStateToStream(stateResponse.getSourceAsBytesRef(), restoreStream); + } + + } + + private void writeStateToStream(BytesReference source, OutputStream stream) throws IOException { + if (isCancelled) { + return; + } + + // The source bytes are already UTF-8. The C++ process wants UTF-8, so we + // can avoid converting to a Java String only to convert back again. + BytesRefIterator iterator = source.iterator(); + for (BytesRef ref = iterator.next(); ref != null; ref = iterator.next()) { + // There's a complication that the source can already have trailing 0 bytes + int length = ref.bytes.length; + while (length > 0 && ref.bytes[length - 1] == 0) { + --length; + } + if (length > 0) { + stream.write(ref.bytes, 0, length); + } + } + // This is dictated by RapidJSON on the C++ side; it treats a '\0' as end-of-file + // even when it's not really end-of-file, and this is what we need because we're + // sending multiple JSON documents via the same named pipe. + stream.write(0); + } +} diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java index 55e2f2cedeb..5d803e4bcdf 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java @@ -19,6 +19,8 @@ import org.elasticsearch.xpack.ml.job.config.DataDescription; import org.elasticsearch.xpack.ml.job.config.Job; import org.elasticsearch.xpack.ml.job.config.JobUpdate; import org.elasticsearch.xpack.ml.job.config.ModelPlotConfig; +import org.elasticsearch.xpack.ml.job.persistence.JobProvider; +import org.elasticsearch.xpack.ml.job.persistence.StateStreamer; import org.elasticsearch.xpack.ml.job.process.CountingInputStream; import org.elasticsearch.xpack.ml.job.process.DataCountsReporter; import org.elasticsearch.xpack.ml.job.process.autodetect.output.AutoDetectResultProcessor; @@ -26,12 +28,14 @@ import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams; import org.elasticsearch.xpack.ml.job.process.autodetect.params.InterimResultsParams; import org.elasticsearch.xpack.ml.job.process.autodetect.state.DataCounts; import org.elasticsearch.xpack.ml.job.process.autodetect.state.ModelSizeStats; +import org.elasticsearch.xpack.ml.job.process.autodetect.state.ModelSnapshot; import org.elasticsearch.xpack.ml.job.process.autodetect.writer.DataToProcessWriter; import org.elasticsearch.xpack.ml.job.process.autodetect.writer.DataToProcessWriterFactory; import java.io.Closeable; import java.io.IOException; import java.io.InputStream; +import java.io.OutputStream; import java.time.Duration; import java.time.ZonedDateTime; import java.util.List; @@ -52,20 +56,23 @@ public class AutodetectCommunicator implements Closeable { private final Job job; private final JobTask jobTask; - private final DataCountsReporter dataCountsReporter; private final AutodetectProcess autodetectProcess; + private final StateStreamer stateStreamer; + private final DataCountsReporter dataCountsReporter; private final AutoDetectResultProcessor autoDetectResultProcessor; private final Consumer onFinishHandler; private final ExecutorService autodetectWorkerExecutor; private final NamedXContentRegistry xContentRegistry; private volatile boolean processKilled; - AutodetectCommunicator(Job job, JobTask jobTask, AutodetectProcess process, DataCountsReporter dataCountsReporter, - AutoDetectResultProcessor autoDetectResultProcessor, Consumer onFinishHandler, - NamedXContentRegistry xContentRegistry, ExecutorService autodetectWorkerExecutor) { + AutodetectCommunicator(Job job, JobTask jobTask, AutodetectProcess process, StateStreamer stateStreamer, + DataCountsReporter dataCountsReporter, AutoDetectResultProcessor autoDetectResultProcessor, + Consumer onFinishHandler, NamedXContentRegistry xContentRegistry, + ExecutorService autodetectWorkerExecutor) { this.job = job; this.jobTask = jobTask; this.autodetectProcess = process; + this.stateStreamer = stateStreamer; this.dataCountsReporter = dataCountsReporter; this.autoDetectResultProcessor = autoDetectResultProcessor; this.onFinishHandler = onFinishHandler; @@ -73,7 +80,8 @@ public class AutodetectCommunicator implements Closeable { this.autodetectWorkerExecutor = autodetectWorkerExecutor; } - public void writeJobInputHeader() throws IOException { + public void init(ModelSnapshot modelSnapshot) throws IOException { + autodetectProcess.restoreState(stateStreamer, modelSnapshot); createProcessWriter(Optional.empty()).writeHeader(); } @@ -129,7 +137,12 @@ public class AutodetectCommunicator implements Closeable { Future future = autodetectWorkerExecutor.submit(() -> { checkProcessIsAlive(); try { - autodetectProcess.close(); + if (autodetectProcess.isReady()) { + autodetectProcess.close(); + } else { + killProcess(false, false); + stateStreamer.cancel(); + } autoDetectResultProcessor.awaitCompletion(); } finally { onFinishHandler.accept(restart ? new ElasticsearchException(reason) : null); diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcess.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcess.java index b7372cbd3c1..7513d10542d 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcess.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcess.java @@ -7,9 +7,12 @@ package org.elasticsearch.xpack.ml.job.process.autodetect; import org.elasticsearch.xpack.ml.job.config.DetectionRule; import org.elasticsearch.xpack.ml.job.config.ModelPlotConfig; +import org.elasticsearch.xpack.ml.job.persistence.JobProvider; +import org.elasticsearch.xpack.ml.job.persistence.StateStreamer; import org.elasticsearch.xpack.ml.job.process.NativeController; import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams; import org.elasticsearch.xpack.ml.job.process.autodetect.params.InterimResultsParams; +import org.elasticsearch.xpack.ml.job.process.autodetect.state.ModelSnapshot; import org.elasticsearch.xpack.ml.job.results.AutodetectResult; import java.io.Closeable; @@ -23,6 +26,19 @@ import java.util.List; */ public interface AutodetectProcess extends Closeable { + /** + * Restore state from the given {@link ModelSnapshot} + * @param stateStreamer the streamer of the job state + * @param modelSnapshot the model snapshot to restore + */ + void restoreState(StateStreamer stateStreamer, ModelSnapshot modelSnapshot); + + /** + * Is the process ready to receive data? + * @return {@code true} if the process is ready to receive data + */ + boolean isReady(); + /** * Write the record to autodetect. The record parameter should not be encoded * (i.e. length encoded) the implementation will appy the corrrect encoding. diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java index 9f04b7d7477..fedeb43b419 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.ml.job.persistence.JobDataCountsPersister; import org.elasticsearch.xpack.ml.job.persistence.JobProvider; import org.elasticsearch.xpack.ml.job.persistence.JobRenormalizedResultsPersister; import org.elasticsearch.xpack.ml.job.persistence.JobResultsPersister; +import org.elasticsearch.xpack.ml.job.persistence.StateStreamer; import org.elasticsearch.xpack.ml.job.process.DataCountsReporter; import org.elasticsearch.xpack.ml.job.process.autodetect.output.AutoDetectResultProcessor; import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams; @@ -263,7 +264,7 @@ public class AutodetectProcessManager extends AbstractComponent { try { AutodetectCommunicator communicator = autoDetectCommunicatorByJob.computeIfAbsent(jobTask.getAllocationId(), id -> create(jobTask, params, ignoreDowntime, handler)); - communicator.writeJobInputHeader(); + communicator.init(params.modelSnapshot()); setJobState(jobTask, JobState.OPENED); } catch (Exception e1) { // No need to log here as the persistent task framework will log it @@ -338,8 +339,8 @@ public class AutodetectProcessManager extends AbstractComponent { } throw e; } - return new AutodetectCommunicator(job, jobTask, process, dataCountsReporter, processor, handler, xContentRegistry, - autodetectWorkerExecutor); + return new AutodetectCommunicator(job, jobTask, process, new StateStreamer(client), dataCountsReporter, processor, handler, + xContentRegistry, autodetectWorkerExecutor); } diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/BlackHoleAutodetectProcess.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/BlackHoleAutodetectProcess.java index c33a0b9ea5f..c86e205c067 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/BlackHoleAutodetectProcess.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/BlackHoleAutodetectProcess.java @@ -7,9 +7,11 @@ package org.elasticsearch.xpack.ml.job.process.autodetect; import org.elasticsearch.xpack.ml.job.config.DetectionRule; import org.elasticsearch.xpack.ml.job.config.ModelPlotConfig; +import org.elasticsearch.xpack.ml.job.persistence.StateStreamer; import org.elasticsearch.xpack.ml.job.process.autodetect.output.FlushAcknowledgement; import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams; import org.elasticsearch.xpack.ml.job.process.autodetect.params.InterimResultsParams; +import org.elasticsearch.xpack.ml.job.process.autodetect.state.ModelSnapshot; import org.elasticsearch.xpack.ml.job.process.autodetect.state.Quantiles; import org.elasticsearch.xpack.ml.job.results.AutodetectResult; @@ -43,6 +45,15 @@ public class BlackHoleAutodetectProcess implements AutodetectProcess { startTime = ZonedDateTime.now(); } + @Override + public void restoreState(StateStreamer stateStreamer, ModelSnapshot modelSnapshot) { + } + + @Override + public boolean isReady() { + return true; + } + @Override public void writeRecord(String[] record) throws IOException { } diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcess.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcess.java index b199f9787d2..ec3b0b927dd 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcess.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcess.java @@ -11,11 +11,13 @@ import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.job.config.DetectionRule; import org.elasticsearch.xpack.ml.job.config.ModelPlotConfig; +import org.elasticsearch.xpack.ml.job.persistence.StateStreamer; import org.elasticsearch.xpack.ml.job.process.NativeControllerHolder; import org.elasticsearch.xpack.ml.job.process.autodetect.output.AutodetectResultsParser; import org.elasticsearch.xpack.ml.job.process.autodetect.output.StateProcessor; import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams; import org.elasticsearch.xpack.ml.job.process.autodetect.params.InterimResultsParams; +import org.elasticsearch.xpack.ml.job.process.autodetect.state.ModelSnapshot; import org.elasticsearch.xpack.ml.job.process.autodetect.writer.ControlMsgToProcessWriter; import org.elasticsearch.xpack.ml.job.process.autodetect.writer.LengthEncodedWriter; import org.elasticsearch.xpack.ml.job.process.logging.CppLogMessageHandler; @@ -49,6 +51,7 @@ class NativeAutodetectProcess implements AutodetectProcess { private final CppLogMessageHandler cppLogHandler; private final OutputStream processInStream; private final InputStream processOutStream; + private final OutputStream processRestoreStream; private final LengthEncodedWriter recordWriter; private final ZonedDateTime startTime; private final int numberOfAnalysisFields; @@ -58,16 +61,17 @@ class NativeAutodetectProcess implements AutodetectProcess { private volatile Future stateProcessorFuture; private volatile boolean processCloseInitiated; private volatile boolean processKilled; + private volatile boolean isReady; private final AutodetectResultsParser resultsParser; - NativeAutodetectProcess(String jobId, InputStream logStream, OutputStream processInStream, - InputStream processOutStream, int numberOfAnalysisFields, - List filesToDelete, AutodetectResultsParser resultsParser, - Runnable onProcessCrash) { + NativeAutodetectProcess(String jobId, InputStream logStream, OutputStream processInStream, InputStream processOutStream, + OutputStream processRestoreStream, int numberOfAnalysisFields, List filesToDelete, + AutodetectResultsParser resultsParser, Runnable onProcessCrash) { this.jobId = jobId; cppLogHandler = new CppLogMessageHandler(jobId, logStream); this.processInStream = new BufferedOutputStream(processInStream); this.processOutStream = processOutStream; + this.processRestoreStream = processRestoreStream; this.recordWriter = new LengthEncodedWriter(this.processInStream); startTime = ZonedDateTime.now(); this.numberOfAnalysisFields = numberOfAnalysisFields; @@ -107,6 +111,26 @@ class NativeAutodetectProcess implements AutodetectProcess { }); } + @Override + public void restoreState(StateStreamer stateStreamer, ModelSnapshot modelSnapshot) { + if (modelSnapshot != null) { + try (OutputStream r = processRestoreStream) { + stateStreamer.restoreStateToStream(jobId, modelSnapshot, r); + } catch (Exception e) { + // TODO: should we fail to start? + if (processKilled == false) { + LOGGER.error("Error restoring model state for job " + jobId, e); + } + } + } + isReady = true; + } + + @Override + public boolean isReady() { + return isReady; + } + @Override public void writeRecord(String[] record) throws IOException { recordWriter.writeRecord(record); diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java index 8ff9808a823..a78c5b20f99 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java @@ -26,7 +26,6 @@ import org.elasticsearch.xpack.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.utils.NamedPipeHelper; import java.io.IOException; -import java.io.OutputStream; import java.nio.file.Path; import java.time.Duration; import java.util.ArrayList; @@ -73,19 +72,10 @@ public class NativeAutodetectProcessFactory implements AutodetectProcessFactory AutodetectResultsParser resultsParser = new AutodetectResultsParser(settings); NativeAutodetectProcess autodetect = new NativeAutodetectProcess( job.getId(), processPipes.getLogStream().get(), processPipes.getProcessInStream().get(), - processPipes.getProcessOutStream().get(), numberOfAnalysisFields, filesToDelete, - resultsParser, onProcessCrash - ); + processPipes.getProcessOutStream().get(), processPipes.getRestoreStream().orElse(null), numberOfAnalysisFields, + filesToDelete, resultsParser, onProcessCrash); try { autodetect.start(executorService, stateProcessor, processPipes.getPersistStream().get()); - if (modelSnapshot != null) { - try (OutputStream r = processPipes.getRestoreStream().get()) { - jobProvider.restoreStateToStream(job.getId(), modelSnapshot, r); - } catch (Exception e) { - // TODO: should we fail to start? - LOGGER.error("Error restoring model state for job " + job.getId(), e); - } - } return autodetect; } catch (EsRejectedExecutionException e) { try { diff --git a/plugin/src/test/java/org/elasticsearch/xpack/ml/action/CloseJobActionRequestTests.java b/plugin/src/test/java/org/elasticsearch/xpack/ml/action/CloseJobActionRequestTests.java index 2d6360a9905..196b59e0b39 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/ml/action/CloseJobActionRequestTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/ml/action/CloseJobActionRequestTests.java @@ -100,10 +100,7 @@ public class CloseJobActionRequestTests extends AbstractStreamableXContentTestCa PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); addJobTask("opening-job", null, null, tasksBuilder); - ElasticsearchStatusException conflictException = - expectThrows(ElasticsearchStatusException.class, () -> - CloseJobAction.validateJobAndTaskState("opening-job", mlBuilder.build(), tasksBuilder.build())); - assertEquals(RestStatus.CONFLICT, conflictException.status()); + CloseJobAction.validateJobAndTaskState("opening-job", mlBuilder.build(), tasksBuilder.build()); } public void testValidate_jobIsMissing() { diff --git a/plugin/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobProviderTests.java b/plugin/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobProviderTests.java index 58a9384a167..27bb8e6dda5 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobProviderTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobProviderTests.java @@ -41,9 +41,7 @@ import org.elasticsearch.xpack.ml.MlMetadata; import org.elasticsearch.xpack.ml.action.util.QueryPage; import org.elasticsearch.xpack.ml.job.config.Job; import org.elasticsearch.xpack.ml.job.persistence.InfluencersQueryBuilder.InfluencersQuery; -import org.elasticsearch.xpack.ml.job.process.autodetect.state.CategorizerState; import org.elasticsearch.xpack.ml.job.process.autodetect.state.ModelSnapshot; -import org.elasticsearch.xpack.ml.job.process.autodetect.state.ModelState; import org.elasticsearch.xpack.ml.job.results.AnomalyRecord; import org.elasticsearch.xpack.ml.job.results.Bucket; import org.elasticsearch.xpack.ml.job.results.CategoryDefinition; @@ -51,9 +49,7 @@ import org.elasticsearch.xpack.ml.job.results.Influencer; import org.elasticsearch.xpack.ml.job.results.Result; import org.mockito.ArgumentCaptor; -import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; import java.util.Date; @@ -804,43 +800,6 @@ public class JobProviderTests extends ESTestCase { assertEquals(6, snapshots.get(1).getSnapshotDocCount()); } - public void testRestoreStateToStream() throws Exception { - String snapshotId = "123"; - Map categorizerState = new HashMap<>(); - categorizerState.put("catName", "catVal"); - GetResponse categorizerStateGetResponse1 = createGetResponse(true, categorizerState); - GetResponse categorizerStateGetResponse2 = createGetResponse(false, null); - Map modelState = new HashMap<>(); - modelState.put("modName", "modVal1"); - GetResponse modelStateGetResponse1 = createGetResponse(true, modelState); - modelState.put("modName", "modVal2"); - GetResponse modelStateGetResponse2 = createGetResponse(true, modelState); - - MockClientBuilder clientBuilder = new MockClientBuilder(CLUSTER_NAME).addClusterStatusYellowResponse() - .prepareGet(AnomalyDetectorsIndex.jobStateIndexName(), ElasticsearchMappings.DOC_TYPE, - CategorizerState.documentId(JOB_ID, 1), categorizerStateGetResponse1) - .prepareGet(AnomalyDetectorsIndex.jobStateIndexName(), ElasticsearchMappings.DOC_TYPE, - CategorizerState.documentId(JOB_ID, 2), categorizerStateGetResponse2) - .prepareGet(AnomalyDetectorsIndex.jobStateIndexName(), ElasticsearchMappings.DOC_TYPE, - ModelState.documentId(JOB_ID, snapshotId, 1), modelStateGetResponse1) - .prepareGet(AnomalyDetectorsIndex.jobStateIndexName(), ElasticsearchMappings.DOC_TYPE, - ModelState.documentId(JOB_ID, snapshotId, 2), modelStateGetResponse2); - - JobProvider provider = createProvider(clientBuilder.build()); - - ModelSnapshot modelSnapshot = new ModelSnapshot.Builder(JOB_ID).setSnapshotId(snapshotId).setSnapshotDocCount(2).build(); - - ByteArrayOutputStream stream = new ByteArrayOutputStream(); - - provider.restoreStateToStream(JOB_ID, modelSnapshot, stream); - - String[] restoreData = stream.toString(StandardCharsets.UTF_8.name()).split("\0"); - assertEquals(3, restoreData.length); - assertEquals("{\"modName\":\"modVal1\"}", restoreData[0]); - assertEquals("{\"modName\":\"modVal2\"}", restoreData[1]); - assertEquals("{\"catName\":\"catVal\"}", restoreData[2]); - } - public void testViolatedFieldCountLimit() throws Exception { Map mapping = new HashMap<>(); for (int i = 0; i < 10; i++) { diff --git a/plugin/src/test/java/org/elasticsearch/xpack/ml/job/persistence/StateStreamerTests.java b/plugin/src/test/java/org/elasticsearch/xpack/ml/job/persistence/StateStreamerTests.java new file mode 100644 index 00000000000..10d86cfcc21 --- /dev/null +++ b/plugin/src/test/java/org/elasticsearch/xpack/ml/job/persistence/StateStreamerTests.java @@ -0,0 +1,86 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.job.persistence; + +import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.mock.orig.Mockito; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.ml.job.process.autodetect.state.CategorizerState; +import org.elasticsearch.xpack.ml.job.process.autodetect.state.ModelSnapshot; +import org.elasticsearch.xpack.ml.job.process.autodetect.state.ModelState; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class StateStreamerTests extends ESTestCase { + + private static final String CLUSTER_NAME = "state_streamer_cluster"; + private static final String JOB_ID = "state_streamer_test_job"; + + public void testRestoreStateToStream() throws Exception { + String snapshotId = "123"; + Map categorizerState = new HashMap<>(); + categorizerState.put("catName", "catVal"); + GetResponse categorizerStateGetResponse1 = createGetResponse(true, categorizerState); + GetResponse categorizerStateGetResponse2 = createGetResponse(false, null); + Map modelState = new HashMap<>(); + modelState.put("modName", "modVal1"); + GetResponse modelStateGetResponse1 = createGetResponse(true, modelState); + modelState.put("modName", "modVal2"); + GetResponse modelStateGetResponse2 = createGetResponse(true, modelState); + + MockClientBuilder clientBuilder = new MockClientBuilder(CLUSTER_NAME).addClusterStatusYellowResponse() + .prepareGet(AnomalyDetectorsIndex.jobStateIndexName(), ElasticsearchMappings.DOC_TYPE, + CategorizerState.documentId(JOB_ID, 1), categorizerStateGetResponse1) + .prepareGet(AnomalyDetectorsIndex.jobStateIndexName(), ElasticsearchMappings.DOC_TYPE, + CategorizerState.documentId(JOB_ID, 2), categorizerStateGetResponse2) + .prepareGet(AnomalyDetectorsIndex.jobStateIndexName(), ElasticsearchMappings.DOC_TYPE, + ModelState.documentId(JOB_ID, snapshotId, 1), modelStateGetResponse1) + .prepareGet(AnomalyDetectorsIndex.jobStateIndexName(), ElasticsearchMappings.DOC_TYPE, + ModelState.documentId(JOB_ID, snapshotId, 2), modelStateGetResponse2); + + + ModelSnapshot modelSnapshot = new ModelSnapshot.Builder(JOB_ID).setSnapshotId(snapshotId).setSnapshotDocCount(2).build(); + + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + + StateStreamer stateStreamer = new StateStreamer(clientBuilder.build()); + stateStreamer.restoreStateToStream(JOB_ID, modelSnapshot, stream); + + String[] restoreData = stream.toString(StandardCharsets.UTF_8.name()).split("\0"); + assertEquals(3, restoreData.length); + assertEquals("{\"modName\":\"modVal1\"}", restoreData[0]); + assertEquals("{\"modName\":\"modVal2\"}", restoreData[1]); + assertEquals("{\"catName\":\"catVal\"}", restoreData[2]); + } + + public void testCancelBeforeRestoreWasCalled() throws IOException { + ModelSnapshot modelSnapshot = new ModelSnapshot.Builder(JOB_ID).setSnapshotId("snapshot_id").setSnapshotDocCount(2).build(); + OutputStream outputStream = mock(OutputStream.class); + StateStreamer stateStreamer = new StateStreamer(mock(Client.class)); + stateStreamer.cancel(); + + stateStreamer.restoreStateToStream(JOB_ID, modelSnapshot, outputStream); + + Mockito.verifyNoMoreInteractions(outputStream); + } + + private static GetResponse createGetResponse(boolean exists, Map source) throws IOException { + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(exists); + when(getResponse.getSourceAsBytesRef()).thenReturn(XContentFactory.jsonBuilder().map(source).bytes()); + return getResponse; + } +} \ No newline at end of file diff --git a/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicatorTests.java b/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicatorTests.java index 404c420ea44..c1930bee8fb 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicatorTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicatorTests.java @@ -15,11 +15,13 @@ import org.elasticsearch.xpack.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.ml.job.config.DataDescription; import org.elasticsearch.xpack.ml.job.config.Detector; import org.elasticsearch.xpack.ml.job.config.Job; +import org.elasticsearch.xpack.ml.job.persistence.StateStreamer; import org.elasticsearch.xpack.ml.job.process.DataCountsReporter; import org.elasticsearch.xpack.ml.job.process.autodetect.output.AutoDetectResultProcessor; import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams; import org.elasticsearch.xpack.ml.job.process.autodetect.params.InterimResultsParams; import org.elasticsearch.xpack.ml.job.process.autodetect.params.TimeRange; +import org.junit.Before; import org.mockito.Mockito; import java.io.ByteArrayInputStream; @@ -47,6 +49,13 @@ import static org.mockito.Mockito.when; public class AutodetectCommunicatorTests extends ESTestCase { + private StateStreamer stateStreamer; + + @Before + public void initMocks() { + stateStreamer = mock(StateStreamer.class); + } + public void testWriteResetBucketsControlMessage() throws IOException { DataLoadParams params = new DataLoadParams(TimeRange.builder().startTime("1").endTime("2").build(), Optional.empty()); AutodetectProcess process = mockAutodetectProcessWithOutputStream(); @@ -107,11 +116,28 @@ public class AutodetectCommunicatorTests extends ESTestCase { verify(process, times(3)).isProcessAlive(); } - public void testClose() throws IOException { + public void testCloseGivenProcessIsReady() throws IOException { AutodetectProcess process = mockAutodetectProcessWithOutputStream(); + when(process.isReady()).thenReturn(true); AutodetectCommunicator communicator = createAutodetectCommunicator(process, mock(AutoDetectResultProcessor.class)); + communicator.close(); - Mockito.verify(process).close(); + + verify(process).close(); + verify(process, never()).kill(); + Mockito.verifyNoMoreInteractions(stateStreamer); + } + + public void testCloseGivenProcessIsNotReady() throws IOException { + AutodetectProcess process = mockAutodetectProcessWithOutputStream(); + when(process.isReady()).thenReturn(false); + AutodetectCommunicator communicator = createAutodetectCommunicator(process, mock(AutoDetectResultProcessor.class)); + + communicator.close(); + + verify(process).kill(); + verify(process, never()).close(); + verify(stateStreamer).cancel(); } public void testKill() throws IOException, TimeoutException { @@ -167,7 +193,7 @@ public class AutodetectCommunicatorTests extends ESTestCase { }).when(dataCountsReporter).finishReporting(any()); JobTask jobTask = mock(JobTask.class); when(jobTask.getJobId()).thenReturn("foo"); - return new AutodetectCommunicator(createJobDetails(), jobTask, autodetectProcess, + return new AutodetectCommunicator(createJobDetails(), jobTask, autodetectProcess, stateStreamer, dataCountsReporter, autoDetectResultProcessor, finishHandler, new NamedXContentRegistry(Collections.emptyList()), executorService); } diff --git a/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessTests.java b/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessTests.java index 94e55464bc6..8dcc333c1dc 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessTests.java @@ -53,7 +53,7 @@ public class NativeAutodetectProcessTests extends ESTestCase { InputStream logStream = mock(InputStream.class); when(logStream.read(new byte[1024])).thenReturn(-1); try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream, - mock(OutputStream.class), mock(InputStream.class), + mock(OutputStream.class), mock(InputStream.class), mock(OutputStream.class), NUMBER_ANALYSIS_FIELDS, null, new AutodetectResultsParser(Settings.EMPTY), mock(Runnable.class))) { process.start(executorService, mock(StateProcessor.class), mock(InputStream.class)); @@ -74,7 +74,7 @@ public class NativeAutodetectProcessTests extends ESTestCase { String[] record = {"r1", "r2", "r3", "r4", "r5"}; ByteArrayOutputStream bos = new ByteArrayOutputStream(1024); try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream, - bos, mock(InputStream.class), NUMBER_ANALYSIS_FIELDS, Collections.emptyList(), + bos, mock(InputStream.class), mock(OutputStream.class), NUMBER_ANALYSIS_FIELDS, Collections.emptyList(), new AutodetectResultsParser(Settings.EMPTY), mock(Runnable.class))) { process.start(executorService, mock(StateProcessor.class), mock(InputStream.class)); @@ -106,7 +106,7 @@ public class NativeAutodetectProcessTests extends ESTestCase { when(logStream.read(new byte[1024])).thenReturn(-1); ByteArrayOutputStream bos = new ByteArrayOutputStream(ControlMsgToProcessWriter.FLUSH_SPACES_LENGTH + 1024); try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream, - bos, mock(InputStream.class), NUMBER_ANALYSIS_FIELDS, Collections.emptyList(), + bos, mock(InputStream.class), mock(OutputStream.class), NUMBER_ANALYSIS_FIELDS, Collections.emptyList(), new AutodetectResultsParser(Settings.EMPTY), mock(Runnable.class))) { process.start(executorService, mock(StateProcessor.class), mock(InputStream.class)); @@ -123,7 +123,7 @@ public class NativeAutodetectProcessTests extends ESTestCase { when(logStream.read(new byte[1024])).thenReturn(-1); ByteArrayOutputStream bos = new ByteArrayOutputStream(1024); try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream, - bos, mock(InputStream.class), NUMBER_ANALYSIS_FIELDS, Collections.emptyList(), + bos, mock(InputStream.class), mock(OutputStream.class), NUMBER_ANALYSIS_FIELDS, Collections.emptyList(), new AutodetectResultsParser(Settings.EMPTY), mock(Runnable.class))) { process.start(executorService, mock(StateProcessor.class), mock(InputStream.class)); @@ -141,7 +141,7 @@ public class NativeAutodetectProcessTests extends ESTestCase { when(logStream.read(new byte[1024])).thenReturn(-1); ByteArrayOutputStream bos = new ByteArrayOutputStream(1024); try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream, - bos, mock(InputStream.class), NUMBER_ANALYSIS_FIELDS, Collections.emptyList(), + bos, mock(InputStream.class), mock(OutputStream.class), NUMBER_ANALYSIS_FIELDS, Collections.emptyList(), new AutodetectResultsParser(Settings.EMPTY), mock(Runnable.class))) { process.start(executorService, mock(StateProcessor.class), mock(InputStream.class));