diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java index 23015dae064..0bffeda4283 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java @@ -35,8 +35,7 @@ import java.time.Duration; import java.time.ZonedDateTime; import java.util.List; import java.util.Optional; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Supplier; @@ -52,7 +51,7 @@ public class AutodetectCommunicator implements Closeable { private final AutoDetectResultProcessor autoDetectResultProcessor; private final Consumer handler; - final AtomicReference inUse = new AtomicReference<>(); + final AtomicBoolean inUse = new AtomicBoolean(false); public AutodetectCommunicator(long taskId, Job job, AutodetectProcess process, DataCountsReporter dataCountsReporter, AutoDetectResultProcessor autoDetectResultProcessor, Consumer handler) { @@ -84,7 +83,7 @@ public class AutodetectCommunicator implements Closeable { DataCounts results = autoDetectWriter.write(countingStream); autoDetectWriter.flush(); return results; - }, false); + }); } @Override @@ -99,22 +98,21 @@ public class AutodetectCommunicator implements Closeable { autoDetectResultProcessor.awaitCompletion(); handler.accept(errorReason != null ? new ElasticsearchException(errorReason) : null); return null; - }, true); + }); } - public void writeUpdateModelDebugMessage(ModelDebugConfig config) throws IOException { checkAndRun(() -> Messages.getMessage(Messages.JOB_DATA_CONCURRENT_USE_UPDATE, job.getId()), () -> { autodetectProcess.writeUpdateModelDebugMessage(config); return null; - }, false); + }); } public void writeUpdateDetectorRulesMessage(int detectorIndex, List rules) throws IOException { checkAndRun(() -> Messages.getMessage(Messages.JOB_DATA_CONCURRENT_USE_UPDATE, job.getId()), () -> { autodetectProcess.writeUpdateDetectorRulesMessage(detectorIndex, rules); return null; - }, false); + }); } public void flushJob(InterimResultsParams params) throws IOException { @@ -122,7 +120,7 @@ public class AutodetectCommunicator implements Closeable { String flushId = autodetectProcess.flushJob(params); waitFlushToCompletion(flushId); return null; - }, false); + }); } private void waitFlushToCompletion(String flushId) throws IOException { @@ -173,32 +171,16 @@ public class AutodetectCommunicator implements Closeable { return taskId; } - private T checkAndRun(Supplier errorMessage, CheckedSupplier callback, boolean wait) throws IOException { - CountDownLatch latch = new CountDownLatch(1); - if (inUse.compareAndSet(null, latch)) { + private T checkAndRun(Supplier errorMessage, CheckedSupplier callback) throws IOException { + if (inUse.compareAndSet(false, true)) { try { checkProcessIsAlive(); return callback.get(); } finally { - latch.countDown(); - inUse.set(null); + inUse.set(false); } } else { - if (wait) { - latch = inUse.get(); - if (latch != null) { - try { - latch.await(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new ElasticsearchStatusException(errorMessage.get(), RestStatus.TOO_MANY_REQUESTS); - } - } - checkProcessIsAlive(); - return callback.get(); - } else { - throw new ElasticsearchStatusException(errorMessage.get(), RestStatus.TOO_MANY_REQUESTS); - } + throw new ElasticsearchStatusException(errorMessage.get(), RestStatus.TOO_MANY_REQUESTS); } } diff --git a/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicatorTests.java b/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicatorTests.java index 3f389eba20d..cb7996f3ede 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicatorTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicatorTests.java @@ -28,7 +28,6 @@ import java.time.Duration; import java.util.Collections; import java.util.List; import java.util.Optional; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import static org.elasticsearch.mock.orig.Mockito.doAnswer; @@ -156,11 +155,11 @@ public class AutodetectCommunicatorTests extends ESTestCase { AutodetectProcess process = mockAutodetectProcessWithOutputStream(); AutodetectCommunicator communicator = createAutodetectCommunicator(process, mock(AutoDetectResultProcessor.class)); - communicator.inUse.set(new CountDownLatch(1)); + communicator.inUse.set(true); expectThrows(ElasticsearchStatusException.class, () -> communicator.writeToJob(in, mock(DataLoadParams.class))); - communicator.inUse.set(null); + communicator.inUse.set(false); communicator.writeToJob(in, new DataLoadParams(TimeRange.builder().build(), Optional.empty())); } @@ -170,11 +169,11 @@ public class AutodetectCommunicatorTests extends ESTestCase { when(resultProcessor.waitForFlushAcknowledgement(any(), any())).thenReturn(true); AutodetectCommunicator communicator = createAutodetectCommunicator(process, resultProcessor); - communicator.inUse.set(new CountDownLatch(1)); + communicator.inUse.set(true); InterimResultsParams params = mock(InterimResultsParams.class); expectThrows(ElasticsearchStatusException.class, () -> communicator.flushJob(params)); - communicator.inUse.set(null); + communicator.inUse.set(false); communicator.flushJob(params); } @@ -184,12 +183,10 @@ public class AutodetectCommunicatorTests extends ESTestCase { when(resultProcessor.waitForFlushAcknowledgement(any(), any())).thenReturn(true); AutodetectCommunicator communicator = createAutodetectCommunicator(process, resultProcessor); - CountDownLatch latch = mock(CountDownLatch.class); - communicator.inUse.set(latch); - communicator.close(); - verify(latch, times(1)).await(); + communicator.inUse.set(true); + expectThrows(ElasticsearchStatusException.class, () -> communicator.close()); - communicator.inUse.set(null); + communicator.inUse.set(false); communicator.close(); } @@ -198,10 +195,10 @@ public class AutodetectCommunicatorTests extends ESTestCase { AutoDetectResultProcessor resultProcessor = mock(AutoDetectResultProcessor.class); AutodetectCommunicator communicator = createAutodetectCommunicator(process, resultProcessor); - communicator.inUse.set(new CountDownLatch(1)); + communicator.inUse.set(true); expectThrows(ElasticsearchStatusException.class, () -> communicator.writeUpdateModelDebugMessage(mock(ModelDebugConfig.class))); - communicator.inUse.set(null); + communicator.inUse.set(false); communicator.writeUpdateModelDebugMessage(mock(ModelDebugConfig.class)); } @@ -211,10 +208,10 @@ public class AutodetectCommunicatorTests extends ESTestCase { AutodetectCommunicator communicator = createAutodetectCommunicator(process, resultProcessor); List rules = Collections.singletonList(mock(DetectionRule.class)); - communicator.inUse.set(new CountDownLatch(1)); + communicator.inUse.set(true); expectThrows(ElasticsearchStatusException.class, () -> communicator.writeUpdateDetectorRulesMessage(0, rules)); - communicator.inUse.set(null); + communicator.inUse.set(false); communicator.writeUpdateDetectorRulesMessage(0, rules); } }