From 5cee9f614842e6dd1cb26b7c06733adee0d5abf8 Mon Sep 17 00:00:00 2001 From: Karan Kumar Date: Fri, 22 Sep 2023 11:21:04 +0530 Subject: [PATCH] Allow cancellation of MSQ tasks if they are waiting for segments to load (#15000) With PR #14322 , MSQ insert/Replace q's will wait for segment to be loaded on the historical's before finishing. The patch introduces a bug where in the main thread had a thread.sleep() which could not be interrupted via the cancel calls from the overlord. This new patch addressed that problem by moving the thread.sleep inside a thread of its own. Thus the main thread is now waiting on the future object of this execution. The cancel call can now shutdown the executor service via another method thus unblocking the main thread to proceed. --- .../apache/druid/msq/exec/ControllerImpl.java | 29 +++- ...ter.java => SegmentLoadStatusFetcher.java} | 145 ++++++++++++------ .../druid/msq/exec/WorkerSketchFetcher.java | 7 +- .../msq/indexing/report/MSQStatusReport.java | 8 +- ...java => SegmentLoadStatusFetcherTest.java} | 89 ++++++++++- .../indexing/report/MSQTaskReportTest.java | 14 +- 6 files changed, 219 insertions(+), 73 deletions(-) rename extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/{SegmentLoadWaiter.java => SegmentLoadStatusFetcher.java} (77%) rename extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/{SegmentLoadWaiterTest.java => SegmentLoadStatusFetcherTest.java} (54%) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index 20033480f10..f81b1e8a827 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -294,7 +294,7 @@ public class ControllerImpl implements Controller private WorkerMemoryParameters workerMemoryParameters; private boolean isDurableStorageEnabled; private boolean isFaultToleranceEnabled; - private volatile SegmentLoadWaiter segmentLoadWaiter; + private volatile SegmentLoadStatusFetcher segmentLoadWaiter; public ControllerImpl( final MSQControllerTask task, @@ -354,6 +354,7 @@ public class ControllerImpl implements Controller // stopGracefully() is called when the containing process is terminated, or when the task is canceled. log.info("Query [%s] canceled.", queryDef != null ? queryDef.getQueryId() : ""); + stopExternalFetchers(); addToKernelManipulationQueue( kernel -> { throw new MSQException(CanceledFault.INSTANCE); @@ -465,7 +466,6 @@ public class ControllerImpl implements Controller try { releaseTaskLocks(); - cleanUpDurableStorageIfNeeded(); if (queryKernel != null && queryKernel.isSuccess()) { @@ -474,6 +474,7 @@ public class ControllerImpl implements Controller segmentLoadWaiter.waitForSegmentsToLoad(); } } + stopExternalFetchers(); } catch (Exception e) { log.warn(e, "Exception thrown during cleanup. Ignoring it and writing task report."); @@ -742,7 +743,7 @@ public class ControllerImpl implements Controller /** * Accepts a {@link PartialKeyStatisticsInformation} and updates the controller key statistics information. If all key * statistics information has been gathered, enqueues the task with the {@link WorkerSketchFetcher} to generate - * partiton boundaries. This is intended to be called by the {@link ControllerChatHandler}. + * partition boundaries. This is intended to be called by the {@link ControllerChatHandler}. */ @Override public void updatePartialKeyStatisticsInformation( @@ -801,7 +802,7 @@ public class ControllerImpl implements Controller /** * This method intakes all the warnings that are generated by the worker. It is the responsibility of the - * worker node to ensure that it doesn't spam the controller with unneseccary warning stack traces. Currently, that + * worker node to ensure that it doesn't spam the controller with unnecessary warning stack traces. Currently, that * limiting is implemented in {@link MSQWarningReportLimiterPublisher} */ @Override @@ -1360,9 +1361,10 @@ public class ControllerImpl implements Controller } } else { Set versionsToAwait = segmentsWithTombstones.stream().map(DataSegment::getVersion).collect(Collectors.toSet()); - segmentLoadWaiter = new SegmentLoadWaiter( + segmentLoadWaiter = new SegmentLoadStatusFetcher( context.injector().getInstance(BrokerClient.class), context.jsonMapper(), + task.getId(), task.getDataSource(), versionsToAwait, segmentsWithTombstones.size(), @@ -1375,9 +1377,10 @@ public class ControllerImpl implements Controller } } else if (!segments.isEmpty()) { Set versionsToAwait = segments.stream().map(DataSegment::getVersion).collect(Collectors.toSet()); - segmentLoadWaiter = new SegmentLoadWaiter( + segmentLoadWaiter = new SegmentLoadStatusFetcher( context.injector().getInstance(BrokerClient.class), context.jsonMapper(), + task.getId(), task.getDataSource(), versionsToAwait, segments.size(), @@ -2129,7 +2132,7 @@ public class ControllerImpl implements Controller @Nullable final DateTime queryStartTime, final long queryDuration, MSQWorkerTaskLauncher taskLauncher, - final SegmentLoadWaiter segmentLoadWaiter + final SegmentLoadStatusFetcher segmentLoadWaiter ) { int pendingTasks = -1; @@ -2141,7 +2144,7 @@ public class ControllerImpl implements Controller runningTasks = workerTaskCount.getRunningWorkerCount() + 1; // To account for controller. } - SegmentLoadWaiter.SegmentLoadWaiterStatus status = segmentLoadWaiter == null ? null : segmentLoadWaiter.status(); + SegmentLoadStatusFetcher.SegmentLoadWaiterStatus status = segmentLoadWaiter == null ? null : segmentLoadWaiter.status(); return new MSQStatusReport( taskState, @@ -2260,6 +2263,16 @@ public class ControllerImpl implements Controller } } + private void stopExternalFetchers() + { + if (workerSketchFetcher != null) { + workerSketchFetcher.close(); + } + if (segmentLoadWaiter != null) { + segmentLoadWaiter.close(); + } + } + /** * Main controller logic for running a multi-stage query. */ diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadWaiter.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java similarity index 77% rename from extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadWaiter.java rename to extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java index 3a54c41e410..478c632a749 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadWaiter.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcher.java @@ -24,9 +24,13 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; +import org.apache.druid.common.guava.FutureUtils; import org.apache.druid.discovery.BrokerClient; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.http.client.Request; import org.apache.druid.sql.http.ResultFormat; @@ -56,9 +60,9 @@ import java.util.concurrent.atomic.AtomicReference; * If the segments are not loaded within {@link #TIMEOUT_DURATION_MILLIS} milliseconds, this logs a warning and exits * for the same reason. */ -public class SegmentLoadWaiter +public class SegmentLoadStatusFetcher implements AutoCloseable { - private static final Logger log = new Logger(SegmentLoadWaiter.class); + private static final Logger log = new Logger(SegmentLoadStatusFetcher.class); private static final long SLEEP_DURATION_MILLIS = TimeUnit.SECONDS.toMillis(5); private static final long TIMEOUT_DURATION_MILLIS = TimeUnit.MINUTES.toMillis(10); @@ -68,9 +72,9 @@ public class SegmentLoadWaiter * - If a segment is not used, the broker will not have any information about it, hence, a COUNT(*) should return the used count only. * - If replication_factor is more than 0, the segment will be loaded on historicals and needs to be waited for. * - If replication_factor is 0, that means that the segment will never be loaded on a historical and does not need to - * be waited for. + * be waited for. * - If replication_factor is -1, the replication factor is not known currently and will become known after a load rule - * evaluation. + * evaluation. *
* See https://github.com/apache/druid/pull/14403 for more details about replication_factor */ @@ -90,11 +94,15 @@ public class SegmentLoadWaiter private final Set versionsToAwait; private final int totalSegmentsGenerated; private final boolean doWait; + // since live reports fetch the value in another thread, we need to use AtomicReference private final AtomicReference status; - public SegmentLoadWaiter( + private final ListeningExecutorService executorService; + + public SegmentLoadStatusFetcher( BrokerClient brokerClient, ObjectMapper objectMapper, + String taskId, String datasource, Set versionsToAwait, int totalSegmentsGenerated, @@ -107,8 +115,19 @@ public class SegmentLoadWaiter this.versionsToAwait = new TreeSet<>(versionsToAwait); this.versionToLoadStatusMap = new HashMap<>(); this.totalSegmentsGenerated = totalSegmentsGenerated; - this.status = new AtomicReference<>(new SegmentLoadWaiterStatus(State.INIT, null, 0, totalSegmentsGenerated, 0, 0, 0, 0, totalSegmentsGenerated)); + this.status = new AtomicReference<>(new SegmentLoadWaiterStatus( + State.INIT, + null, + 0, + totalSegmentsGenerated, + 0, + 0, + 0, + 0, + totalSegmentsGenerated + )); this.doWait = doWait; + this.executorService = MoreExecutors.listeningDecorator(Execs.singleThreaded(taskId + "-segment-load-waiter-%d")); } /** @@ -122,57 +141,73 @@ public class SegmentLoadWaiter */ public void waitForSegmentsToLoad() { - DateTime startTime = DateTimes.nowUtc(); - boolean hasAnySegmentBeenLoaded = false; - + final DateTime startTime = DateTimes.nowUtc(); + final AtomicReference hasAnySegmentBeenLoaded = new AtomicReference<>(false); try { - while (!versionsToAwait.isEmpty()) { - // Check the timeout and exit if exceeded. - long runningMillis = new Interval(startTime, DateTimes.nowUtc()).toDurationMillis(); - if (runningMillis > TIMEOUT_DURATION_MILLIS) { - log.warn("Runtime [%s] exceeded timeout [%s] while waiting for segments to load. Exiting.", runningMillis, TIMEOUT_DURATION_MILLIS); - updateStatus(State.TIMED_OUT, startTime); - return; - } + FutureUtils.getUnchecked(executorService.submit(() -> { + try { + while (!versionsToAwait.isEmpty()) { + // Check the timeout and exit if exceeded. + long runningMillis = new Interval(startTime, DateTimes.nowUtc()).toDurationMillis(); + if (runningMillis > TIMEOUT_DURATION_MILLIS) { + log.warn( + "Runtime[%d] exceeded timeout[%d] while waiting for segments to load. Exiting.", + runningMillis, + TIMEOUT_DURATION_MILLIS + ); + updateStatus(State.TIMED_OUT, startTime); + return; + } - Iterator iterator = versionsToAwait.iterator(); + Iterator iterator = versionsToAwait.iterator(); + log.info( + "Fetching segment load status for datasource[%s] from broker for segment versions[%s]", + datasource, + versionsToAwait + ); - // Query the broker for all pending versions - while (iterator.hasNext()) { - String version = iterator.next(); + // Query the broker for all pending versions + while (iterator.hasNext()) { + String version = iterator.next(); - // Fetch the load status for this version from the broker - VersionLoadStatus loadStatus = fetchLoadStatusForVersion(version); - versionToLoadStatusMap.put(version, loadStatus); + // Fetch the load status for this version from the broker + VersionLoadStatus loadStatus = fetchLoadStatusForVersion(version); + versionToLoadStatusMap.put(version, loadStatus); + hasAnySegmentBeenLoaded.set(hasAnySegmentBeenLoaded.get() || loadStatus.getUsedSegments() > 0); - hasAnySegmentBeenLoaded = hasAnySegmentBeenLoaded || loadStatus.getUsedSegments() > 0; + // If loading is done for this stage, remove it from future loops. + if (hasAnySegmentBeenLoaded.get() && loadStatus.isLoadingComplete()) { + iterator.remove(); + } + } - // If loading is done for this stage, remove it from future loops. - if (hasAnySegmentBeenLoaded && loadStatus.isLoadingComplete()) { - iterator.remove(); + if (!versionsToAwait.isEmpty()) { + // Update the status. + updateStatus(State.WAITING, startTime); + // Sleep for a bit before checking again. + waitIfNeeded(SLEEP_DURATION_MILLIS); + } } } - - if (!versionsToAwait.isEmpty()) { - // Update the status. - updateStatus(State.WAITING, startTime); - - // Sleep for a while before retrying. - waitIfNeeded(SLEEP_DURATION_MILLIS); + catch (Exception e) { + log.warn(e, "Exception occurred while waiting for segments to load. Exiting."); + // Update the status and return. + updateStatus(State.FAILED, startTime); + return; } - } + // Update the status. + log.info("Segment loading completed for datasource[%s]", datasource); + updateStatus(State.SUCCESS, startTime); + }), true); } catch (Exception e) { log.warn(e, "Exception occurred while waiting for segments to load. Exiting."); - - // Update the status and return. updateStatus(State.FAILED, startTime); - return; } - // Update the status. - updateStatus(State.SUCCESS, startTime); + finally { + executorService.shutdownNow(); + } } - private void waitIfNeeded(long waitTimeMillis) throws Exception { if (doWait) { @@ -219,9 +254,9 @@ public class SegmentLoadWaiter Request request = brokerClient.makeRequest(HttpMethod.POST, "/druid/v2/sql/"); SqlQuery sqlQuery = new SqlQuery(StringUtils.format(LOAD_QUERY, datasource, version), ResultFormat.OBJECTLINES, - false, false, false, null, null); + false, false, false, null, null + ); request.setContent(MediaType.APPLICATION_JSON, objectMapper.writeValueAsBytes(sqlQuery)); - String response = brokerClient.sendQuery(request); if (response.trim().isEmpty()) { @@ -240,6 +275,17 @@ public class SegmentLoadWaiter return status.get(); } + @Override + public void close() + { + try { + executorService.shutdownNow(); + } + catch (Throwable suppressed) { + log.warn(suppressed, "Error shutting down SegmentLoadStatusFetcher"); + } + } + public static class SegmentLoadWaiterStatus { private final State state; @@ -254,7 +300,7 @@ public class SegmentLoadWaiter @JsonCreator public SegmentLoadWaiterStatus( - @JsonProperty("state") SegmentLoadWaiter.State state, + @JsonProperty("state") SegmentLoadStatusFetcher.State state, @JsonProperty("startTime") @Nullable DateTime startTime, @JsonProperty("duration") long duration, @JsonProperty("totalSegments") int totalSegments, @@ -277,7 +323,7 @@ public class SegmentLoadWaiter } @JsonProperty - public SegmentLoadWaiter.State getState() + public SegmentLoadStatusFetcher.State getState() { return state; } @@ -356,7 +402,12 @@ public class SegmentLoadWaiter * The time spent waiting for segments to load exceeded org.apache.druid.msq.exec.SegmentLoadWaiter#TIMEOUT_DURATION_MILLIS. * The SegmentLoadWaiter exited without failing the task. */ - TIMED_OUT + TIMED_OUT; + + public boolean isFinished() + { + return this == SUCCESS || this == FAILED || this == TIMED_OUT; + } } public static class VersionLoadStatus diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java index d9f2291afcd..271ce8ff070 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java @@ -304,6 +304,11 @@ public class WorkerSketchFetcher implements AutoCloseable @Override public void close() { - executorService.shutdownNow(); + try { + executorService.shutdownNow(); + } + catch (Throwable suppressed) { + log.warn(suppressed, "Error while shutting down WorkerSketchFetcher"); + } } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java index ea721d84f47..d3864498349 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java @@ -24,7 +24,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; import org.apache.druid.indexer.TaskState; -import org.apache.druid.msq.exec.SegmentLoadWaiter; +import org.apache.druid.msq.exec.SegmentLoadStatusFetcher; import org.apache.druid.msq.indexing.error.MSQErrorReport; import org.joda.time.DateTime; @@ -52,7 +52,7 @@ public class MSQStatusReport private final int runningTasks; @Nullable - private final SegmentLoadWaiter.SegmentLoadWaiterStatus segmentLoadWaiterStatus; + private final SegmentLoadStatusFetcher.SegmentLoadWaiterStatus segmentLoadWaiterStatus; @JsonCreator public MSQStatusReport( @@ -63,7 +63,7 @@ public class MSQStatusReport @JsonProperty("durationMs") long durationMs, @JsonProperty("pendingTasks") int pendingTasks, @JsonProperty("runningTasks") int runningTasks, - @JsonProperty("segmentLoadWaiterStatus") @Nullable SegmentLoadWaiter.SegmentLoadWaiterStatus segmentLoadWaiterStatus + @JsonProperty("segmentLoadWaiterStatus") @Nullable SegmentLoadStatusFetcher.SegmentLoadWaiterStatus segmentLoadWaiterStatus ) { this.status = Preconditions.checkNotNull(status, "status"); @@ -126,7 +126,7 @@ public class MSQStatusReport @Nullable @JsonProperty @JsonInclude(JsonInclude.Include.NON_NULL) - public SegmentLoadWaiter.SegmentLoadWaiterStatus getSegmentLoadWaiterStatus() + public SegmentLoadStatusFetcher.SegmentLoadWaiterStatus getSegmentLoadWaiterStatus() { return segmentLoadWaiterStatus; } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadWaiterTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcherTest.java similarity index 54% rename from extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadWaiterTest.java rename to extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcherTest.java index e14fa5faec2..f2ffa0c9ec7 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadWaiterTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadStatusFetcherTest.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableSet; import org.apache.druid.discovery.BrokerClient; import org.apache.druid.java.util.http.client.Request; +import org.junit.Assert; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -35,11 +36,11 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -public class SegmentLoadWaiterTest +public class SegmentLoadStatusFetcherTest { private static final String TEST_DATASOURCE = "testDatasource"; - private SegmentLoadWaiter segmentLoadWaiter; + private SegmentLoadStatusFetcher segmentLoadWaiter; private BrokerClient brokerClient; @@ -55,15 +56,30 @@ public class SegmentLoadWaiterTest doAnswer(new Answer() { int timesInvoked = 0; + @Override public String answer(InvocationOnMock invocation) throws Throwable { timesInvoked += 1; - SegmentLoadWaiter.VersionLoadStatus loadStatus = new SegmentLoadWaiter.VersionLoadStatus(5, timesInvoked, 0, 5 - timesInvoked, 0); + SegmentLoadStatusFetcher.VersionLoadStatus loadStatus = new SegmentLoadStatusFetcher.VersionLoadStatus( + 5, + timesInvoked, + 0, + 5 - timesInvoked, + 0 + ); return new ObjectMapper().writeValueAsString(loadStatus); } }).when(brokerClient).sendQuery(any()); - segmentLoadWaiter = new SegmentLoadWaiter(brokerClient, new ObjectMapper(), TEST_DATASOURCE, ImmutableSet.of("version1"), 5, false); + segmentLoadWaiter = new SegmentLoadStatusFetcher( + brokerClient, + new ObjectMapper(), + "id", + TEST_DATASOURCE, + ImmutableSet.of("version1"), + 5, + false + ); segmentLoadWaiter.waitForSegmentsToLoad(); verify(brokerClient, times(5)).sendQuery(any()); @@ -78,18 +94,79 @@ public class SegmentLoadWaiterTest doAnswer(new Answer() { int timesInvoked = 0; + @Override public String answer(InvocationOnMock invocation) throws Throwable { timesInvoked += 1; - SegmentLoadWaiter.VersionLoadStatus loadStatus = new SegmentLoadWaiter.VersionLoadStatus(5, timesInvoked, 0, 5 - timesInvoked, 0); + SegmentLoadStatusFetcher.VersionLoadStatus loadStatus = new SegmentLoadStatusFetcher.VersionLoadStatus( + 5, + timesInvoked, + 0, + 5 - timesInvoked, + 0 + ); return new ObjectMapper().writeValueAsString(loadStatus); } }).when(brokerClient).sendQuery(any()); - segmentLoadWaiter = new SegmentLoadWaiter(brokerClient, new ObjectMapper(), TEST_DATASOURCE, ImmutableSet.of("version1"), 5, false); + segmentLoadWaiter = new SegmentLoadStatusFetcher( + brokerClient, + new ObjectMapper(), + "id", + TEST_DATASOURCE, + ImmutableSet.of("version1"), + 5, + false + ); segmentLoadWaiter.waitForSegmentsToLoad(); verify(brokerClient, times(5)).sendQuery(any()); } + @Test + public void triggerCancellationFromAnotherThread() throws Exception + { + brokerClient = mock(BrokerClient.class); + doReturn(mock(Request.class)).when(brokerClient).makeRequest(any(), anyString()); + doAnswer(new Answer() + { + int timesInvoked = 0; + + @Override + public String answer(InvocationOnMock invocation) throws Throwable + { + // sleeping broker call to simulate a long running query + Thread.sleep(1000); + timesInvoked++; + SegmentLoadStatusFetcher.VersionLoadStatus loadStatus = new SegmentLoadStatusFetcher.VersionLoadStatus( + 5, + timesInvoked, + 0, + 5 - timesInvoked, + 0 + ); + return new ObjectMapper().writeValueAsString(loadStatus); + } + }).when(brokerClient).sendQuery(any()); + segmentLoadWaiter = new SegmentLoadStatusFetcher( + brokerClient, + new ObjectMapper(), + "id", + TEST_DATASOURCE, + ImmutableSet.of("version1"), + 5, + true + ); + + Thread t = new Thread(() -> segmentLoadWaiter.waitForSegmentsToLoad()); + t.start(); + // call close from main thread + segmentLoadWaiter.close(); + t.join(1000); + Assert.assertFalse(t.isAlive()); + + Assert.assertTrue(segmentLoadWaiter.status().getState().isFinished()); + Assert.assertTrue(segmentLoadWaiter.status().getState() == SegmentLoadStatusFetcher.State.FAILED); + } + } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java index ef50008d48e..3b49572996e 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java @@ -35,7 +35,7 @@ import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.msq.counters.CounterSnapshotsTree; -import org.apache.druid.msq.exec.SegmentLoadWaiter; +import org.apache.druid.msq.exec.SegmentLoadStatusFetcher; import org.apache.druid.msq.guice.MSQIndexingModule; import org.apache.druid.msq.indexing.error.MSQErrorReport; import org.apache.druid.msq.indexing.error.TooManyColumnsFault; @@ -92,8 +92,8 @@ public class MSQTaskReportTest new Object[]{"bar"} ); - SegmentLoadWaiter.SegmentLoadWaiterStatus status = new SegmentLoadWaiter.SegmentLoadWaiterStatus( - SegmentLoadWaiter.State.WAITING, + SegmentLoadStatusFetcher.SegmentLoadWaiterStatus status = new SegmentLoadStatusFetcher.SegmentLoadWaiterStatus( + SegmentLoadStatusFetcher.State.WAITING, DateTimes.nowUtc(), 200L, 100, @@ -156,8 +156,8 @@ public class MSQTaskReportTest @Test public void testSerdeErrorReport() throws Exception { - SegmentLoadWaiter.SegmentLoadWaiterStatus status = new SegmentLoadWaiter.SegmentLoadWaiterStatus( - SegmentLoadWaiter.State.FAILED, + SegmentLoadStatusFetcher.SegmentLoadWaiterStatus status = new SegmentLoadStatusFetcher.SegmentLoadWaiterStatus( + SegmentLoadStatusFetcher.State.FAILED, DateTimes.nowUtc(), 200L, 100, @@ -205,8 +205,8 @@ public class MSQTaskReportTest @Test public void testWriteTaskReport() throws Exception { - SegmentLoadWaiter.SegmentLoadWaiterStatus status = new SegmentLoadWaiter.SegmentLoadWaiterStatus( - SegmentLoadWaiter.State.SUCCESS, + SegmentLoadStatusFetcher.SegmentLoadWaiterStatus status = new SegmentLoadStatusFetcher.SegmentLoadWaiterStatus( + SegmentLoadStatusFetcher.State.SUCCESS, DateTimes.nowUtc(), 200L, 100,