Remove stray reference to fix OOM while merging sketches (#13475)

* Remove stray reference to fix OOM while merging sketches

* Update future to add result from executor service

* Update tests and address review comments

* Address review comments

* Moved mock

* Close threadpool on teardown

* Remove worker task cancel
Adarsh Sanjeev 2022-12-08 07:17:55 +05:30 committed by GitHub
parent 69951273b8
commit fbf76ad8f5
2 changed files with 19 additions and 116 deletions
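
For context on the title: the "stray reference" removed below is a whenComplete callback registered on the long-lived partition future. CompletableFuture retains every registered callback until the future completes, and each callback here captured a per-worker snapshot future, so merged sketches stayed reachable for the lifetime of the stage. A minimal, self-contained sketch of that leak pattern (all names and sizes are hypothetical, not Druid code):

    import java.util.concurrent.CompletableFuture;

    public class StrayReferenceSketch
    {
        public static void main(String[] args)
        {
            // Long-lived future that only completes when the whole stage resolves.
            CompletableFuture<Void> partitionFuture = new CompletableFuture<>();

            for (int worker = 0; worker < 100; worker++) {
                // Stand-in for one worker's sketch snapshot: a large, already-consumed result.
                CompletableFuture<byte[]> snapshotFuture =
                    CompletableFuture.completedFuture(new byte[8 * 1024 * 1024]);

                // The lambda captures snapshotFuture, and partitionFuture holds the lambda
                // until it completes, so each payload stays reachable long after its result
                // has been merged. With real sketch sizes and worker counts this can
                // exhaust the heap.
                partitionFuture.whenComplete((result, exception) -> {
                    if (exception != null) {
                        snapshotFuture.cancel(true);
                    }
                });
            }
        }
    }

Deleting the hook gives up sibling-task cancellation on failure (hence "Remove worker task cancel" above) but lets each snapshot future become garbage as soon as its result is merged.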

WorkerSketchFetcher.java

@@ -59,7 +59,11 @@ public class WorkerSketchFetcher implements AutoCloseable
private final WorkerClient workerClient;
private final ExecutorService executorService;
public WorkerSketchFetcher(WorkerClient workerClient, ClusterStatisticsMergeMode clusterStatisticsMergeMode, int statisticsMaxRetainedBytes)
public WorkerSketchFetcher(
WorkerClient workerClient,
ClusterStatisticsMergeMode clusterStatisticsMergeMode,
int statisticsMaxRetainedBytes
)
{
this.workerClient = workerClient;
this.clusterStatisticsMergeMode = clusterStatisticsMergeMode;
@@ -86,14 +90,14 @@ public class WorkerSketchFetcher implements AutoCloseable
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
case AUTO:
if (clusterBy.getBucketByCount() == 0) {
log.debug("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
// If there is no time clustering, there is no scope for sequential merge
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
} else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD || completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
log.debug("Query [%s] AUTO mode: chose SEQUENTIAL mode to merge key statistics", stageDefinition.getId().getQueryId());
log.info("Query [%s] AUTO mode: chose SEQUENTIAL mode to merge key statistics", stageDefinition.getId().getQueryId());
return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
}
log.debug("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
default:
throw new IllegalStateException("No fetching strategy found for mode: " + clusterStatisticsMergeMode);
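
The hunk above only promotes the mode-choice logs from debug to info, but the AUTO branch it touches is the decision point: sequential merge is chosen only when time bucketing exists and the stage is large by worker count or by retained sketch bytes. A condensed restatement of that selection logic (threshold values are placeholders for the WORKER_THRESHOLD and BYTES_THRESHOLD constants in the real class):

    public class AutoModeSketch
    {
        enum MergeMode { PARALLEL, SEQUENTIAL }

        // Condensed restatement of the AUTO branch; thresholds are illustrative.
        static MergeMode chooseAutoMode(int bucketByCount, int maxWorkerCount, long bytesRetained)
        {
            final int workerThreshold = 100;
            final long bytesThreshold = 300_000_000L;

            if (bucketByCount == 0) {
                // No time clustering, so there is no scope for a sequential merge.
                return MergeMode.PARALLEL;
            }
            if (maxWorkerCount > workerThreshold || bytesRetained > bytesThreshold) {
                // Big stages merge one time chunk at a time to bound retained memory.
                return MergeMode.SEQUENTIAL;
            }
            return MergeMode.PARALLEL;
        }

        public static void main(String[] args)
        {
            System.out.println(chooseAutoMode(0, 500, 0L));    // PARALLEL
            System.out.println(chooseAutoMode(2, 500, 0L));    // SEQUENTIAL
            System.out.println(chooseAutoMode(2, 10, 1_000L)); // PARALLEL
        }
    }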
@@ -128,12 +132,6 @@ public class WorkerSketchFetcher implements AutoCloseable
stageDefinition.getId().getQueryId(),
stageDefinition.getStageNumber()
);
partitionFuture.whenComplete((result, exception) -> {
if (exception != null || (result != null && result.isError())) {
snapshotFuture.cancel(true);
}
});
try {
ClusterByStatisticsSnapshot clusterByStatisticsSnapshot = snapshotFuture.get();
if (clusterByStatisticsSnapshot == null) {
@@ -151,12 +149,15 @@ public class WorkerSketchFetcher implements AutoCloseable
}
catch (Exception e) {
synchronized (mergedStatisticsCollector) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
if (!partitionFuture.isDone()) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
}
}
}
});
});
return partitionFuture;
}
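
These two hunks are the heart of the fix on the parallel path: the whenComplete hook is gone, and the catch block now settles the partition future directly, guarded by isDone so that only the first failure completes the future and clears the merged collector. A standalone sketch of that first-failure-wins idiom (names hypothetical):

    import java.util.concurrent.CompletableFuture;

    public class FirstFailureWinsSketch
    {
        private static final Object lock = new Object();

        public static void main(String[] args)
        {
            CompletableFuture<String> partitionFuture = new CompletableFuture<>();

            // Two workers fail; only the first performs the non-idempotent cleanup.
            for (int worker = 0; worker < 2; worker++) {
                RuntimeException error = new RuntimeException("worker " + worker + " failed");
                synchronized (lock) {
                    if (!partitionFuture.isDone()) {
                        partitionFuture.completeExceptionally(error);
                        System.out.println("cleared merged statistics after: " + error.getMessage());
                    }
                }
            }
            System.out.println("failed: " + partitionFuture.isCompletedExceptionally());
        }
    }

completeExceptionally is already first-wins on its own; the explicit guard matters because the collector clear() beside it must not run again after the future has settled.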
@@ -247,11 +248,6 @@ public class WorkerSketchFetcher implements AutoCloseable
stageDefinition.getStageNumber(),
timeChunk
);
partitionFuture.whenComplete((result, exception) -> {
if (exception != null || (result != null && result.isError())) {
snapshotFuture.cancel(true);
}
});
try {
ClusterByStatisticsSnapshot snapshotForTimeChunk = snapshotFuture.get();
@@ -289,8 +285,10 @@
}
catch (Exception e) {
synchronized (mergedStatisticsCollector) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
if (!partitionFuture.isDone()) {
partitionFuture.completeExceptionally(e);
mergedStatisticsCollector.clear();
}
}
}
});
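
The same guard is applied to the sequential time-chunk path above. Note from the hunk headers that WorkerSketchFetcher implements AutoCloseable, which the test change below leans on. A minimal sketch of that ownership pattern, with hypothetical names (whether the real close() drains or interrupts its pool is not shown in this diff; the sketch interrupts):

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    // Hypothetical stand-in for a fetcher that owns its executor and releases it in close().
    public class FetcherLike implements AutoCloseable
    {
        private final ExecutorService executorService = Executors.newFixedThreadPool(4);

        public void submit(Runnable work)
        {
            executorService.submit(work);
        }

        @Override
        public void close()
        {
            // Interrupt in-flight fetches rather than draining them; appropriate when
            // the controller is tearing the query down anyway.
            executorService.shutdownNow();
        }

        public static void main(String[] args)
        {
            // try-with-resources guarantees the thread pool is released.
            try (FetcherLike fetcher = new FetcherLike()) {
                fetcher.submit(() -> System.out.println("fetching snapshot"));
            }
        }
    }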

WorkerSketchFetcherTest.java

@@ -23,7 +23,6 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
@@ -46,7 +45,6 @@ import java.util.Queue;
import java.util.Set;
import java.util.SortedMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
@@ -56,7 +54,6 @@ import static org.easymock.EasyMock.mock;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
@@ -107,52 +104,8 @@ public class WorkerSketchFetcherTest
public void tearDown() throws Exception
{
mocks.close();
}
@Test
public void test_submitFetcherTask_parallelFetch_workerThrowsException_shouldCancelOtherTasks() throws Exception
{
// Store futures in a queue
final Queue<ListenableFuture<ClusterByStatisticsSnapshot>> futureQueue = new ConcurrentLinkedQueue<>();
final List<String> workerIds = ImmutableList.of("0", "1", "2", "3");
final CountDownLatch latch = new CountDownLatch(workerIds.size());
target = spy(new WorkerSketchFetcher(workerClient, ClusterStatisticsMergeMode.PARALLEL, 300_000_000));
// When fetching snapshots, return a mock and add future to queue
doAnswer(invocation -> {
ListenableFuture<ClusterByStatisticsSnapshot> snapshotListenableFuture =
spy(Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)));
futureQueue.add(snapshotListenableFuture);
latch.countDown();
latch.await();
return snapshotListenableFuture;
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt());
// Cause a worker to fail instead of returning the result
doAnswer(invocation -> {
latch.countDown();
latch.await();
return Futures.immediateFailedFuture(new InterruptedException("interrupted"));
}).when(workerClient).fetchClusterByStatisticsSnapshot(eq("2"), any(), anyInt());
CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
workerIds,
stageDefinition
);
// Assert that the final result is failed and all other task futures are also cancelled.
Assert.assertThrows(CompletionException.class, eitherCompletableFuture::join);
Thread.sleep(1000);
Assert.assertTrue(eitherCompletableFuture.isCompletedExceptionally());
// Verify that the statistics collector was cleared due to the error.
verify(mergedClusterByStatisticsCollector1, times(1)).clear();
// Verify that other task futures were requested to be cancelled.
Assert.assertFalse(futureQueue.isEmpty());
for (ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture : futureQueue) {
verify(snapshotFuture, times(1)).cancel(eq(true));
if (target != null) {
target.close();
}
}
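
The added lines close the fetcher in tearDown; the null guard matters because target is assigned only by tests that construct one. The resulting teardown shape, sketched with the hypothetical FetcherLike from the earlier example:

    import org.junit.After;

    public class FetcherLikeTest
    {
        private FetcherLike target; // assigned only by tests that build a fetcher

        @After
        public void tearDown()
        {
            // Releases the fetcher's thread pool even when an assertion failed mid-test;
            // skipping this leaks threads across the test suite.
            if (target != null) {
                target.close();
            }
        }
    }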
@@ -194,54 +147,6 @@ public class WorkerSketchFetcherTest
Assert.assertEquals(expectedPartitions1, eitherCompletableFuture.get().valueOrThrow());
}
@Test
public void test_submitFetcherTask_sequentialFetch_workerThrowsException_shouldCancelOtherTasks() throws Exception
{
// Store futures in a queue
final Queue<ListenableFuture<ClusterByStatisticsSnapshot>> futureQueue = new ConcurrentLinkedQueue<>();
SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap = ImmutableSortedMap.of(1L, ImmutableSet.of(0, 1, 2), 2L, ImmutableSet.of(0, 1, 4));
doReturn(timeSegmentVsWorkerMap).when(completeKeyStatisticsInformation).getTimeSegmentVsWorkerMap();
final CyclicBarrier barrier = new CyclicBarrier(3);
target = spy(new WorkerSketchFetcher(workerClient, ClusterStatisticsMergeMode.SEQUENTIAL, 300_000_000));
// When fetching snapshots, return a mock and add future to queue
doAnswer(invocation -> {
ListenableFuture<ClusterByStatisticsSnapshot> snapshotListenableFuture =
spy(Futures.immediateFuture(mock(ClusterByStatisticsSnapshot.class)));
futureQueue.add(snapshotListenableFuture);
barrier.await();
return snapshotListenableFuture;
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(anyString(), anyString(), anyInt(), anyLong());
// Cause a worker in the second time chunk to fail instead of returning the result
doAnswer(invocation -> {
barrier.await();
return Futures.immediateFailedFuture(new InterruptedException("interrupted"));
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(eq("4"), any(), anyInt(), eq(2L));
CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
completeKeyStatisticsInformation,
ImmutableList.of("0", "1", "2", "3", "4"),
stageDefinition
);
// Assert that the final result is failed and all other task futures are also cancelled.
Assert.assertThrows(CompletionException.class, eitherCompletableFuture::join);
Thread.sleep(1000);
Assert.assertTrue(eitherCompletableFuture.isCompletedExceptionally());
// Verify that the correct statistics collector was cleared due to the error.
verify(mergedClusterByStatisticsCollector1, times(0)).clear();
verify(mergedClusterByStatisticsCollector2, times(1)).clear();
// Verify that other task futures were requested to be cancelled.
Assert.assertFalse(futureQueue.isEmpty());
for (ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture : futureQueue) {
verify(snapshotFuture, times(1)).cancel(eq(true));
}
}
@Test
public void test_submitFetcherTask_sequentialFetch_mergePerformedCorrectly()
throws ExecutionException, InterruptedException