Allow cancellation of MSQ tasks if they are waiting for segments to load (#15000)

With PR #14322 , MSQ insert/Replace q's will wait for segment to be loaded on the historical's before finishing.

The patch introduces a bug where in the main thread had a thread.sleep() which could not be interrupted via the cancel calls from the overlord.

This new patch addressed that problem by moving the thread.sleep inside a thread of its own. Thus the main thread is now waiting on the future object of this execution.

The cancel call can now shutdown the executor service via another method thus unblocking the main thread to proceed.
This commit is contained in:
Karan Kumar 2023-09-22 11:21:04 +05:30 committed by GitHub
parent 409bffe7f2
commit 5cee9f6148
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 219 additions and 73 deletions

View File

@ -294,7 +294,7 @@ public class ControllerImpl implements Controller
private WorkerMemoryParameters workerMemoryParameters;
private boolean isDurableStorageEnabled;
private boolean isFaultToleranceEnabled;
private volatile SegmentLoadWaiter segmentLoadWaiter;
private volatile SegmentLoadStatusFetcher segmentLoadWaiter;
public ControllerImpl(
final MSQControllerTask task,
@ -354,6 +354,7 @@ public class ControllerImpl implements Controller
// stopGracefully() is called when the containing process is terminated, or when the task is canceled.
log.info("Query [%s] canceled.", queryDef != null ? queryDef.getQueryId() : "<no id yet>");
stopExternalFetchers();
addToKernelManipulationQueue(
kernel -> {
throw new MSQException(CanceledFault.INSTANCE);
@ -465,7 +466,6 @@ public class ControllerImpl implements Controller
try {
releaseTaskLocks();
cleanUpDurableStorageIfNeeded();
if (queryKernel != null && queryKernel.isSuccess()) {
@ -474,6 +474,7 @@ public class ControllerImpl implements Controller
segmentLoadWaiter.waitForSegmentsToLoad();
}
}
stopExternalFetchers();
}
catch (Exception e) {
log.warn(e, "Exception thrown during cleanup. Ignoring it and writing task report.");
@ -742,7 +743,7 @@ public class ControllerImpl implements Controller
/**
* Accepts a {@link PartialKeyStatisticsInformation} and updates the controller key statistics information. If all key
* statistics information has been gathered, enqueues the task with the {@link WorkerSketchFetcher} to generate
* partiton boundaries. This is intended to be called by the {@link ControllerChatHandler}.
* partition boundaries. This is intended to be called by the {@link ControllerChatHandler}.
*/
@Override
public void updatePartialKeyStatisticsInformation(
@ -801,7 +802,7 @@ public class ControllerImpl implements Controller
/**
* This method intakes all the warnings that are generated by the worker. It is the responsibility of the
* worker node to ensure that it doesn't spam the controller with unneseccary warning stack traces. Currently, that
* worker node to ensure that it doesn't spam the controller with unnecessary warning stack traces. Currently, that
* limiting is implemented in {@link MSQWarningReportLimiterPublisher}
*/
@Override
@ -1360,9 +1361,10 @@ public class ControllerImpl implements Controller
}
} else {
Set<String> versionsToAwait = segmentsWithTombstones.stream().map(DataSegment::getVersion).collect(Collectors.toSet());
segmentLoadWaiter = new SegmentLoadWaiter(
segmentLoadWaiter = new SegmentLoadStatusFetcher(
context.injector().getInstance(BrokerClient.class),
context.jsonMapper(),
task.getId(),
task.getDataSource(),
versionsToAwait,
segmentsWithTombstones.size(),
@ -1375,9 +1377,10 @@ public class ControllerImpl implements Controller
}
} else if (!segments.isEmpty()) {
Set<String> versionsToAwait = segments.stream().map(DataSegment::getVersion).collect(Collectors.toSet());
segmentLoadWaiter = new SegmentLoadWaiter(
segmentLoadWaiter = new SegmentLoadStatusFetcher(
context.injector().getInstance(BrokerClient.class),
context.jsonMapper(),
task.getId(),
task.getDataSource(),
versionsToAwait,
segments.size(),
@ -2129,7 +2132,7 @@ public class ControllerImpl implements Controller
@Nullable final DateTime queryStartTime,
final long queryDuration,
MSQWorkerTaskLauncher taskLauncher,
final SegmentLoadWaiter segmentLoadWaiter
final SegmentLoadStatusFetcher segmentLoadWaiter
)
{
int pendingTasks = -1;
@ -2141,7 +2144,7 @@ public class ControllerImpl implements Controller
runningTasks = workerTaskCount.getRunningWorkerCount() + 1; // To account for controller.
}
SegmentLoadWaiter.SegmentLoadWaiterStatus status = segmentLoadWaiter == null ? null : segmentLoadWaiter.status();
SegmentLoadStatusFetcher.SegmentLoadWaiterStatus status = segmentLoadWaiter == null ? null : segmentLoadWaiter.status();
return new MSQStatusReport(
taskState,
@ -2260,6 +2263,16 @@ public class ControllerImpl implements Controller
}
}
private void stopExternalFetchers()
{
if (workerSketchFetcher != null) {
workerSketchFetcher.close();
}
if (segmentLoadWaiter != null) {
segmentLoadWaiter.close();
}
}
/**
* Main controller logic for running a multi-stage query.
*/

View File

@ -24,9 +24,13 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import org.apache.druid.common.guava.FutureUtils;
import org.apache.druid.discovery.BrokerClient;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.http.client.Request;
import org.apache.druid.sql.http.ResultFormat;
@ -56,9 +60,9 @@ import java.util.concurrent.atomic.AtomicReference;
* If the segments are not loaded within {@link #TIMEOUT_DURATION_MILLIS} milliseconds, this logs a warning and exits
* for the same reason.
*/
public class SegmentLoadWaiter
public class SegmentLoadStatusFetcher implements AutoCloseable
{
private static final Logger log = new Logger(SegmentLoadWaiter.class);
private static final Logger log = new Logger(SegmentLoadStatusFetcher.class);
private static final long SLEEP_DURATION_MILLIS = TimeUnit.SECONDS.toMillis(5);
private static final long TIMEOUT_DURATION_MILLIS = TimeUnit.MINUTES.toMillis(10);
@ -68,9 +72,9 @@ public class SegmentLoadWaiter
* - If a segment is not used, the broker will not have any information about it, hence, a COUNT(*) should return the used count only.
* - If replication_factor is more than 0, the segment will be loaded on historicals and needs to be waited for.
* - If replication_factor is 0, that means that the segment will never be loaded on a historical and does not need to
* be waited for.
* be waited for.
* - If replication_factor is -1, the replication factor is not known currently and will become known after a load rule
* evaluation.
* evaluation.
* <br>
* See https://github.com/apache/druid/pull/14403 for more details about replication_factor
*/
@ -90,11 +94,15 @@ public class SegmentLoadWaiter
private final Set<String> versionsToAwait;
private final int totalSegmentsGenerated;
private final boolean doWait;
// since live reports fetch the value in another thread, we need to use AtomicReference
private final AtomicReference<SegmentLoadWaiterStatus> status;
public SegmentLoadWaiter(
private final ListeningExecutorService executorService;
public SegmentLoadStatusFetcher(
BrokerClient brokerClient,
ObjectMapper objectMapper,
String taskId,
String datasource,
Set<String> versionsToAwait,
int totalSegmentsGenerated,
@ -107,8 +115,19 @@ public class SegmentLoadWaiter
this.versionsToAwait = new TreeSet<>(versionsToAwait);
this.versionToLoadStatusMap = new HashMap<>();
this.totalSegmentsGenerated = totalSegmentsGenerated;
this.status = new AtomicReference<>(new SegmentLoadWaiterStatus(State.INIT, null, 0, totalSegmentsGenerated, 0, 0, 0, 0, totalSegmentsGenerated));
this.status = new AtomicReference<>(new SegmentLoadWaiterStatus(
State.INIT,
null,
0,
totalSegmentsGenerated,
0,
0,
0,
0,
totalSegmentsGenerated
));
this.doWait = doWait;
this.executorService = MoreExecutors.listeningDecorator(Execs.singleThreaded(taskId + "-segment-load-waiter-%d"));
}
/**
@ -122,57 +141,73 @@ public class SegmentLoadWaiter
*/
public void waitForSegmentsToLoad()
{
DateTime startTime = DateTimes.nowUtc();
boolean hasAnySegmentBeenLoaded = false;
final DateTime startTime = DateTimes.nowUtc();
final AtomicReference<Boolean> hasAnySegmentBeenLoaded = new AtomicReference<>(false);
try {
while (!versionsToAwait.isEmpty()) {
// Check the timeout and exit if exceeded.
long runningMillis = new Interval(startTime, DateTimes.nowUtc()).toDurationMillis();
if (runningMillis > TIMEOUT_DURATION_MILLIS) {
log.warn("Runtime [%s] exceeded timeout [%s] while waiting for segments to load. Exiting.", runningMillis, TIMEOUT_DURATION_MILLIS);
updateStatus(State.TIMED_OUT, startTime);
return;
}
FutureUtils.getUnchecked(executorService.submit(() -> {
try {
while (!versionsToAwait.isEmpty()) {
// Check the timeout and exit if exceeded.
long runningMillis = new Interval(startTime, DateTimes.nowUtc()).toDurationMillis();
if (runningMillis > TIMEOUT_DURATION_MILLIS) {
log.warn(
"Runtime[%d] exceeded timeout[%d] while waiting for segments to load. Exiting.",
runningMillis,
TIMEOUT_DURATION_MILLIS
);
updateStatus(State.TIMED_OUT, startTime);
return;
}
Iterator<String> iterator = versionsToAwait.iterator();
Iterator<String> iterator = versionsToAwait.iterator();
log.info(
"Fetching segment load status for datasource[%s] from broker for segment versions[%s]",
datasource,
versionsToAwait
);
// Query the broker for all pending versions
while (iterator.hasNext()) {
String version = iterator.next();
// Query the broker for all pending versions
while (iterator.hasNext()) {
String version = iterator.next();
// Fetch the load status for this version from the broker
VersionLoadStatus loadStatus = fetchLoadStatusForVersion(version);
versionToLoadStatusMap.put(version, loadStatus);
// Fetch the load status for this version from the broker
VersionLoadStatus loadStatus = fetchLoadStatusForVersion(version);
versionToLoadStatusMap.put(version, loadStatus);
hasAnySegmentBeenLoaded.set(hasAnySegmentBeenLoaded.get() || loadStatus.getUsedSegments() > 0);
hasAnySegmentBeenLoaded = hasAnySegmentBeenLoaded || loadStatus.getUsedSegments() > 0;
// If loading is done for this stage, remove it from future loops.
if (hasAnySegmentBeenLoaded.get() && loadStatus.isLoadingComplete()) {
iterator.remove();
}
}
// If loading is done for this stage, remove it from future loops.
if (hasAnySegmentBeenLoaded && loadStatus.isLoadingComplete()) {
iterator.remove();
if (!versionsToAwait.isEmpty()) {
// Update the status.
updateStatus(State.WAITING, startTime);
// Sleep for a bit before checking again.
waitIfNeeded(SLEEP_DURATION_MILLIS);
}
}
}
if (!versionsToAwait.isEmpty()) {
// Update the status.
updateStatus(State.WAITING, startTime);
// Sleep for a while before retrying.
waitIfNeeded(SLEEP_DURATION_MILLIS);
catch (Exception e) {
log.warn(e, "Exception occurred while waiting for segments to load. Exiting.");
// Update the status and return.
updateStatus(State.FAILED, startTime);
return;
}
}
// Update the status.
log.info("Segment loading completed for datasource[%s]", datasource);
updateStatus(State.SUCCESS, startTime);
}), true);
}
catch (Exception e) {
log.warn(e, "Exception occurred while waiting for segments to load. Exiting.");
// Update the status and return.
updateStatus(State.FAILED, startTime);
return;
}
// Update the status.
updateStatus(State.SUCCESS, startTime);
finally {
executorService.shutdownNow();
}
}
private void waitIfNeeded(long waitTimeMillis) throws Exception
{
if (doWait) {
@ -219,9 +254,9 @@ public class SegmentLoadWaiter
Request request = brokerClient.makeRequest(HttpMethod.POST, "/druid/v2/sql/");
SqlQuery sqlQuery = new SqlQuery(StringUtils.format(LOAD_QUERY, datasource, version),
ResultFormat.OBJECTLINES,
false, false, false, null, null);
false, false, false, null, null
);
request.setContent(MediaType.APPLICATION_JSON, objectMapper.writeValueAsBytes(sqlQuery));
String response = brokerClient.sendQuery(request);
if (response.trim().isEmpty()) {
@ -240,6 +275,17 @@ public class SegmentLoadWaiter
return status.get();
}
@Override
public void close()
{
try {
executorService.shutdownNow();
}
catch (Throwable suppressed) {
log.warn(suppressed, "Error shutting down SegmentLoadStatusFetcher");
}
}
public static class SegmentLoadWaiterStatus
{
private final State state;
@ -254,7 +300,7 @@ public class SegmentLoadWaiter
@JsonCreator
public SegmentLoadWaiterStatus(
@JsonProperty("state") SegmentLoadWaiter.State state,
@JsonProperty("state") SegmentLoadStatusFetcher.State state,
@JsonProperty("startTime") @Nullable DateTime startTime,
@JsonProperty("duration") long duration,
@JsonProperty("totalSegments") int totalSegments,
@ -277,7 +323,7 @@ public class SegmentLoadWaiter
}
@JsonProperty
public SegmentLoadWaiter.State getState()
public SegmentLoadStatusFetcher.State getState()
{
return state;
}
@ -356,7 +402,12 @@ public class SegmentLoadWaiter
* The time spent waiting for segments to load exceeded org.apache.druid.msq.exec.SegmentLoadWaiter#TIMEOUT_DURATION_MILLIS.
* The SegmentLoadWaiter exited without failing the task.
*/
TIMED_OUT
TIMED_OUT;
public boolean isFinished()
{
return this == SUCCESS || this == FAILED || this == TIMED_OUT;
}
}
public static class VersionLoadStatus

View File

@ -304,6 +304,11 @@ public class WorkerSketchFetcher implements AutoCloseable
@Override
public void close()
{
executorService.shutdownNow();
try {
executorService.shutdownNow();
}
catch (Throwable suppressed) {
log.warn(suppressed, "Error while shutting down WorkerSketchFetcher");
}
}
}

View File

@ -24,7 +24,7 @@ import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import org.apache.druid.indexer.TaskState;
import org.apache.druid.msq.exec.SegmentLoadWaiter;
import org.apache.druid.msq.exec.SegmentLoadStatusFetcher;
import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.joda.time.DateTime;
@ -52,7 +52,7 @@ public class MSQStatusReport
private final int runningTasks;
@Nullable
private final SegmentLoadWaiter.SegmentLoadWaiterStatus segmentLoadWaiterStatus;
private final SegmentLoadStatusFetcher.SegmentLoadWaiterStatus segmentLoadWaiterStatus;
@JsonCreator
public MSQStatusReport(
@ -63,7 +63,7 @@ public class MSQStatusReport
@JsonProperty("durationMs") long durationMs,
@JsonProperty("pendingTasks") int pendingTasks,
@JsonProperty("runningTasks") int runningTasks,
@JsonProperty("segmentLoadWaiterStatus") @Nullable SegmentLoadWaiter.SegmentLoadWaiterStatus segmentLoadWaiterStatus
@JsonProperty("segmentLoadWaiterStatus") @Nullable SegmentLoadStatusFetcher.SegmentLoadWaiterStatus segmentLoadWaiterStatus
)
{
this.status = Preconditions.checkNotNull(status, "status");
@ -126,7 +126,7 @@ public class MSQStatusReport
@Nullable
@JsonProperty
@JsonInclude(JsonInclude.Include.NON_NULL)
public SegmentLoadWaiter.SegmentLoadWaiterStatus getSegmentLoadWaiterStatus()
public SegmentLoadStatusFetcher.SegmentLoadWaiterStatus getSegmentLoadWaiterStatus()
{
return segmentLoadWaiterStatus;
}

View File

@ -23,6 +23,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.discovery.BrokerClient;
import org.apache.druid.java.util.http.client.Request;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
@ -35,11 +36,11 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
public class SegmentLoadWaiterTest
public class SegmentLoadStatusFetcherTest
{
private static final String TEST_DATASOURCE = "testDatasource";
private SegmentLoadWaiter segmentLoadWaiter;
private SegmentLoadStatusFetcher segmentLoadWaiter;
private BrokerClient brokerClient;
@ -55,15 +56,30 @@ public class SegmentLoadWaiterTest
doAnswer(new Answer<String>()
{
int timesInvoked = 0;
@Override
public String answer(InvocationOnMock invocation) throws Throwable
{
timesInvoked += 1;
SegmentLoadWaiter.VersionLoadStatus loadStatus = new SegmentLoadWaiter.VersionLoadStatus(5, timesInvoked, 0, 5 - timesInvoked, 0);
SegmentLoadStatusFetcher.VersionLoadStatus loadStatus = new SegmentLoadStatusFetcher.VersionLoadStatus(
5,
timesInvoked,
0,
5 - timesInvoked,
0
);
return new ObjectMapper().writeValueAsString(loadStatus);
}
}).when(brokerClient).sendQuery(any());
segmentLoadWaiter = new SegmentLoadWaiter(brokerClient, new ObjectMapper(), TEST_DATASOURCE, ImmutableSet.of("version1"), 5, false);
segmentLoadWaiter = new SegmentLoadStatusFetcher(
brokerClient,
new ObjectMapper(),
"id",
TEST_DATASOURCE,
ImmutableSet.of("version1"),
5,
false
);
segmentLoadWaiter.waitForSegmentsToLoad();
verify(brokerClient, times(5)).sendQuery(any());
@ -78,18 +94,79 @@ public class SegmentLoadWaiterTest
doAnswer(new Answer<String>()
{
int timesInvoked = 0;
@Override
public String answer(InvocationOnMock invocation) throws Throwable
{
timesInvoked += 1;
SegmentLoadWaiter.VersionLoadStatus loadStatus = new SegmentLoadWaiter.VersionLoadStatus(5, timesInvoked, 0, 5 - timesInvoked, 0);
SegmentLoadStatusFetcher.VersionLoadStatus loadStatus = new SegmentLoadStatusFetcher.VersionLoadStatus(
5,
timesInvoked,
0,
5 - timesInvoked,
0
);
return new ObjectMapper().writeValueAsString(loadStatus);
}
}).when(brokerClient).sendQuery(any());
segmentLoadWaiter = new SegmentLoadWaiter(brokerClient, new ObjectMapper(), TEST_DATASOURCE, ImmutableSet.of("version1"), 5, false);
segmentLoadWaiter = new SegmentLoadStatusFetcher(
brokerClient,
new ObjectMapper(),
"id",
TEST_DATASOURCE,
ImmutableSet.of("version1"),
5,
false
);
segmentLoadWaiter.waitForSegmentsToLoad();
verify(brokerClient, times(5)).sendQuery(any());
}
@Test
public void triggerCancellationFromAnotherThread() throws Exception
{
brokerClient = mock(BrokerClient.class);
doReturn(mock(Request.class)).when(brokerClient).makeRequest(any(), anyString());
doAnswer(new Answer<String>()
{
int timesInvoked = 0;
@Override
public String answer(InvocationOnMock invocation) throws Throwable
{
// sleeping broker call to simulate a long running query
Thread.sleep(1000);
timesInvoked++;
SegmentLoadStatusFetcher.VersionLoadStatus loadStatus = new SegmentLoadStatusFetcher.VersionLoadStatus(
5,
timesInvoked,
0,
5 - timesInvoked,
0
);
return new ObjectMapper().writeValueAsString(loadStatus);
}
}).when(brokerClient).sendQuery(any());
segmentLoadWaiter = new SegmentLoadStatusFetcher(
brokerClient,
new ObjectMapper(),
"id",
TEST_DATASOURCE,
ImmutableSet.of("version1"),
5,
true
);
Thread t = new Thread(() -> segmentLoadWaiter.waitForSegmentsToLoad());
t.start();
// call close from main thread
segmentLoadWaiter.close();
t.join(1000);
Assert.assertFalse(t.isAlive());
Assert.assertTrue(segmentLoadWaiter.status().getState().isFinished());
Assert.assertTrue(segmentLoadWaiter.status().getState() == SegmentLoadStatusFetcher.State.FAILED);
}
}

View File

@ -35,7 +35,7 @@ import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.guava.Yielder;
import org.apache.druid.java.util.common.guava.Yielders;
import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.exec.SegmentLoadWaiter;
import org.apache.druid.msq.exec.SegmentLoadStatusFetcher;
import org.apache.druid.msq.guice.MSQIndexingModule;
import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.apache.druid.msq.indexing.error.TooManyColumnsFault;
@ -92,8 +92,8 @@ public class MSQTaskReportTest
new Object[]{"bar"}
);
SegmentLoadWaiter.SegmentLoadWaiterStatus status = new SegmentLoadWaiter.SegmentLoadWaiterStatus(
SegmentLoadWaiter.State.WAITING,
SegmentLoadStatusFetcher.SegmentLoadWaiterStatus status = new SegmentLoadStatusFetcher.SegmentLoadWaiterStatus(
SegmentLoadStatusFetcher.State.WAITING,
DateTimes.nowUtc(),
200L,
100,
@ -156,8 +156,8 @@ public class MSQTaskReportTest
@Test
public void testSerdeErrorReport() throws Exception
{
SegmentLoadWaiter.SegmentLoadWaiterStatus status = new SegmentLoadWaiter.SegmentLoadWaiterStatus(
SegmentLoadWaiter.State.FAILED,
SegmentLoadStatusFetcher.SegmentLoadWaiterStatus status = new SegmentLoadStatusFetcher.SegmentLoadWaiterStatus(
SegmentLoadStatusFetcher.State.FAILED,
DateTimes.nowUtc(),
200L,
100,
@ -205,8 +205,8 @@ public class MSQTaskReportTest
@Test
public void testWriteTaskReport() throws Exception
{
SegmentLoadWaiter.SegmentLoadWaiterStatus status = new SegmentLoadWaiter.SegmentLoadWaiterStatus(
SegmentLoadWaiter.State.SUCCESS,
SegmentLoadStatusFetcher.SegmentLoadWaiterStatus status = new SegmentLoadStatusFetcher.SegmentLoadWaiterStatus(
SegmentLoadStatusFetcher.State.SUCCESS,
DateTimes.nowUtc(),
200L,
100,