RemoteTaskRunner: Fix NPE in streamTaskReports. (#12006)

* RemoteTaskRunner: Fix NPE in streamTaskReports.

It is possible for a work item to drop out of runningTasks after the
ZkWorker is retrieved. In this case, the current code would throw
an NPE.

* Additional tests and additional fixes.

* Fix import.
This commit is contained in:
Gian Merlino 2022-05-19 14:23:55 -07:00 committed by GitHub
parent 65a1375b67
commit 5f95cc61fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 206 additions and 44 deletions

View File

@ -660,37 +660,50 @@ public class RemoteTaskRunner implements WorkerTaskRunner, TaskLogStreamer
if (zkWorker == null) {
// Worker is not running this task, it might be available in deep storage
return Optional.absent();
} else {
TaskLocation taskLocation = runningTasks.get(taskId).getLocation();
final URL url = TaskRunnerUtils.makeTaskLocationURL(
taskLocation,
"/druid/worker/v1/chat/%s/liveReports",
taskId
);
return Optional.of(
new ByteSource()
}
final RemoteTaskRunnerWorkItem runningWorkItem = runningTasks.get(taskId);
if (runningWorkItem == null) {
// Worker very recently exited.
return Optional.absent();
}
final TaskLocation taskLocation = runningWorkItem.getLocation();
if (TaskLocation.unknown().equals(taskLocation)) {
// No location known for this task. It may have not been assigned one yet.
return Optional.absent();
}
final URL url = TaskRunnerUtils.makeTaskLocationURL(
taskLocation,
"/druid/worker/v1/chat/%s/liveReports",
taskId
);
return Optional.of(
new ByteSource()
{
@Override
public InputStream openStream() throws IOException
{
@Override
public InputStream openStream() throws IOException
{
try {
return httpClient.go(
new Request(HttpMethod.GET, url),
new InputStreamResponseHandler()
).get();
}
catch (InterruptedException e) {
throw new RuntimeException(e);
}
catch (ExecutionException e) {
// Unwrap if possible
Throwables.propagateIfPossible(e.getCause(), IOException.class);
throw new RuntimeException(e);
}
try {
return httpClient.go(
new Request(HttpMethod.GET, url),
new InputStreamResponseHandler()
).get();
}
catch (InterruptedException e) {
throw new RuntimeException(e);
}
catch (ExecutionException e) {
// Unwrap if possible
Throwables.propagateIfPossible(e.getCause(), IOException.class);
throw new RuntimeException(e);
}
}
);
}
}
);
}
/**

View File

@ -1024,6 +1024,12 @@ public class HttpRemoteTaskRunner implements WorkerTaskRunner, TaskLogStreamer
} else {
// Worker is still running this task
TaskLocation taskLocation = taskRunnerWorkItem.getLocation();
if (TaskLocation.unknown().equals(taskLocation)) {
// No location known for this task. It may have not been assigned a location yet.
return Optional.absent();
}
final URL url = TaskRunnerUtils.makeTaskLocationURL(
taskLocation,
"/druid/worker/v1/chat/%s/liveReports",

View File

@ -19,12 +19,16 @@
package org.apache.druid.indexing.common.tasklogs;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.ByteStreams;
import com.google.common.io.Files;
import org.apache.druid.indexing.common.TaskReport;
import org.apache.druid.indexing.common.config.FileTaskLogsConfig;
import org.apache.druid.java.util.common.FileUtils;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.tasklogs.TaskLogs;
import org.junit.Assert;
import org.junit.Rule;
@ -69,6 +73,28 @@ public class FileTaskLogsTest
}
}
@Test
public void testSimpleReport() throws Exception
{
final ObjectMapper mapper = TestHelper.makeJsonMapper();
final File tmpDir = temporaryFolder.newFolder();
final File logDir = new File(tmpDir, "druid/logs");
final File reportFile = new File(tmpDir, "report.json");
final String taskId = "myTask";
final TestTaskReport testReport = new TestTaskReport(taskId);
final String testReportString = mapper.writeValueAsString(TaskReport.buildTaskReports(testReport));
Files.write(testReportString, reportFile, StandardCharsets.UTF_8);
final TaskLogs taskLogs = new FileTaskLogs(new FileTaskLogsConfig(logDir));
taskLogs.pushTaskReports("foo", reportFile);
Assert.assertEquals(
testReportString,
StringUtils.fromUtf8(ByteStreams.toByteArray(taskLogs.streamTaskReports("foo").get().openStream()))
);
}
@Test
public void testPushTaskLogDirCreationFails() throws Exception
{
@ -123,4 +149,37 @@ public class FileTaskLogsTest
{
return StringUtils.fromUtf8(ByteStreams.toByteArray(taskLogs.streamTaskLog(logFile, offset).get().openStream()));
}
private static class TestTaskReport implements TaskReport
{
static final String KEY = "testReport";
static final Map<String, Object> PAYLOAD = ImmutableMap.of("foo", "bar");
private final String taskId;
public TestTaskReport(String taskId)
{
this.taskId = taskId;
}
@Override
@JsonProperty
public String getTaskId()
{
return taskId;
}
@Override
public String getReportKey()
{
return KEY;
}
@Override
@JsonProperty
public Object getPayload()
{
return PAYLOAD;
}
}
}

View File

@ -73,9 +73,17 @@ public class OverlordBlinkLeadershipTest
public void testOverlordBlinkLeadership()
{
try {
RemoteTaskRunner remoteTaskRunner1 = rtrUtils.makeRemoteTaskRunner(remoteTaskRunnerConfig, resourceManagement);
RemoteTaskRunner remoteTaskRunner1 = rtrUtils.makeRemoteTaskRunner(
remoteTaskRunnerConfig,
resourceManagement,
null
);
remoteTaskRunner1.stop();
RemoteTaskRunner remoteTaskRunner2 = rtrUtils.makeRemoteTaskRunner(remoteTaskRunnerConfig, resourceManagement);
RemoteTaskRunner remoteTaskRunner2 = rtrUtils.makeRemoteTaskRunner(
remoteTaskRunnerConfig,
resourceManagement,
null
);
remoteTaskRunner2.stop();
}
catch (Exception e) {

View File

@ -69,7 +69,8 @@ public class RemoteTaskRunnerRunPendingTasksConcurrencyTest
{
return 2;
}
}
},
null
);
int numTasks = 6;

View File

@ -22,14 +22,18 @@ package org.apache.druid.indexing.overlord;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Function;
import com.google.common.base.Joiner;
import com.google.common.base.Optional;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import org.apache.curator.framework.CuratorFramework;
import org.apache.curator.framework.recipes.cache.PathChildrenCache;
import org.apache.druid.indexer.TaskLocation;
import org.apache.druid.indexer.TaskState;
import org.apache.druid.indexer.TaskStatus;
import org.apache.druid.indexing.common.IndexingServiceCondition;
@ -46,6 +50,8 @@ import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.emitter.EmittingLogger;
import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.java.util.http.client.HttpClient;
import org.apache.druid.java.util.http.client.Request;
import org.apache.druid.server.metrics.NoopServiceEmitter;
import org.apache.druid.testing.DeadlockDetectingTimeout;
import org.easymock.Capture;
@ -61,6 +67,9 @@ import org.junit.rules.TestWatcher;
import org.junit.runner.Description;
import org.mockito.Mockito;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;
@ -73,13 +82,17 @@ public class RemoteTaskRunnerTest
private static final Logger LOG = new Logger(RemoteTaskRunnerTest.class);
private static final Joiner JOINER = RemoteTaskRunnerTestUtils.JOINER;
private static final String WORKER_HOST = "worker";
private static final String ANNOUCEMENTS_PATH = JOINER.join(RemoteTaskRunnerTestUtils.ANNOUNCEMENTS_PATH, WORKER_HOST);
private static final String ANNOUCEMENTS_PATH = JOINER.join(
RemoteTaskRunnerTestUtils.ANNOUNCEMENTS_PATH,
WORKER_HOST
);
private static final String STATUS_PATH = JOINER.join(RemoteTaskRunnerTestUtils.STATUS_PATH, WORKER_HOST);
// higher timeout to reduce flakiness on CI pipeline
private static final Period TIMEOUT_PERIOD = Period.millis(30000);
private RemoteTaskRunner remoteTaskRunner;
private HttpClient httpClient;
private RemoteTaskRunnerTestUtils rtrTestUtils = new RemoteTaskRunnerTestUtils();
private ObjectMapper jsonMapper;
private CuratorFramework cf;
@ -401,7 +414,8 @@ public class RemoteTaskRunnerTest
new TaskResource("first", 1),
"foo",
TaskStatus.running("first"),
jsonMapper);
jsonMapper
);
remoteTaskRunner.run(task1);
Assert.assertTrue(taskAnnounced(task1.getId()));
mockWorkerRunningTask(task1);
@ -411,15 +425,17 @@ public class RemoteTaskRunnerTest
new TaskResource("task", 2),
"foo",
TaskStatus.running("task"),
jsonMapper);
jsonMapper
);
remoteTaskRunner.run(task);
TestRealtimeTask task2 = new TestRealtimeTask(
"second",
new TaskResource("second", 2),
"foo",
TaskStatus.running("second"),
jsonMapper);
jsonMapper
);
remoteTaskRunner.run(task2);
Assert.assertTrue(taskAnnounced(task2.getId()));
mockWorkerRunningTask(task2);
@ -449,7 +465,8 @@ public class RemoteTaskRunnerTest
new TaskResource("testTask", 2),
"foo",
TaskStatus.success("testTask"),
jsonMapper);
jsonMapper
);
remoteTaskRunner.run(task1);
Assert.assertTrue(taskAnnounced(task1.getId()));
mockWorkerRunningTask(task1);
@ -590,7 +607,8 @@ public class RemoteTaskRunnerTest
private void makeRemoteTaskRunner(RemoteTaskRunnerConfig config)
{
remoteTaskRunner = rtrTestUtils.makeRemoteTaskRunner(config);
httpClient = EasyMock.createMock(HttpClient.class);
remoteTaskRunner = rtrTestUtils.makeRemoteTaskRunner(config, httpClient);
}
private void makeWorker() throws Exception
@ -1022,7 +1040,10 @@ public class RemoteTaskRunnerTest
mockWorkerCompleteFailedTask(task3);
Assert.assertTrue(taskFuture3.get().isFailure());
Assert.assertEquals(1, remoteTaskRunner.getBlackListedWorkers().size());
Assert.assertEquals(3, remoteTaskRunner.getBlacklistedTaskSlotCount().get(WorkerConfig.DEFAULT_CATEGORY).longValue());
Assert.assertEquals(
3,
remoteTaskRunner.getBlacklistedTaskSlotCount().get(WorkerConfig.DEFAULT_CATEGORY).longValue()
);
mockWorkerCompleteSuccessfulTask(task2);
Assert.assertTrue(taskFuture2.get().isSuccess());
@ -1046,7 +1067,8 @@ public class RemoteTaskRunnerTest
PathChildrenCache cache = new PathChildrenCache(cf, "/test", true);
testStartWithNoWorker();
cache.getListenable().addListener(remoteTaskRunner.getStatusListener(worker, new ZkWorker(worker, cache, jsonMapper), null));
cache.getListenable()
.addListener(remoteTaskRunner.getStatusListener(worker, new ZkWorker(worker, cache, jsonMapper), null));
cache.start(PathChildrenCache.StartMode.POST_INITIALIZED_EVENT);
// Status listener will recieve event with null data
@ -1062,4 +1084,56 @@ public class RemoteTaskRunnerTest
Assert.assertNull(alertDataMap.get("znode"));
// Status listener should successfully completes without throwing exception
}
@Test
public void testStreamTaskReportsUnknownTask() throws Exception
{
doSetup();
Assert.assertEquals(Optional.absent(), remoteTaskRunner.streamTaskReports("foo"));
}
@Test
public void testStreamTaskReportsKnownTask() throws Exception
{
doSetup();
final Capture<Request> capturedRequest = Capture.newInstance();
final String reportString = "my report!";
final ByteArrayInputStream reportResponse = new ByteArrayInputStream(StringUtils.toUtf8(reportString));
EasyMock.expect(httpClient.go(EasyMock.capture(capturedRequest), EasyMock.anyObject()))
.andReturn(Futures.immediateFuture(reportResponse));
EasyMock.replay(httpClient);
ListenableFuture<TaskStatus> result = remoteTaskRunner.run(task);
Assert.assertTrue(taskAnnounced(task.getId()));
mockWorkerRunningTask(task);
// Wait for the task to have a known location.
Assert.assertTrue(
TestUtils.conditionValid(
() ->
!remoteTaskRunner.getRunningTasks().isEmpty()
&& !Iterables.getOnlyElement(remoteTaskRunner.getRunningTasks())
.getLocation()
.equals(TaskLocation.unknown())
)
);
// Stream task reports from a running task.
final InputStream in = remoteTaskRunner.streamTaskReports(task.getId()).get().openStream();
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
ByteStreams.copy(in, baos);
Assert.assertEquals(reportString, StringUtils.fromUtf8(baos.toByteArray()));
// Stream task reports from a completed task.
mockWorkerCompleteSuccessfulTask(task);
Assert.assertTrue(workerCompletedTask(result));
Assert.assertEquals(Optional.absent(), remoteTaskRunner.streamTaskReports(task.getId()));
// Verify the HTTP request.
EasyMock.verify(httpClient);
Assert.assertEquals(
"http://dummy:9000/druid/worker/v1/chat/task%20id%20with%20spaces/liveReports",
capturedRequest.getValue().getUrl().toString()
);
}
}

View File

@ -106,15 +106,16 @@ public class RemoteTaskRunnerTestUtils
testingCluster.stop();
}
RemoteTaskRunner makeRemoteTaskRunner(RemoteTaskRunnerConfig config)
RemoteTaskRunner makeRemoteTaskRunner(RemoteTaskRunnerConfig config, HttpClient httpClient)
{
NoopProvisioningStrategy<WorkerTaskRunner> resourceManagement = new NoopProvisioningStrategy<>();
return makeRemoteTaskRunner(config, resourceManagement);
return makeRemoteTaskRunner(config, resourceManagement, httpClient);
}
public RemoteTaskRunner makeRemoteTaskRunner(
RemoteTaskRunnerConfig config,
ProvisioningStrategy<WorkerTaskRunner> provisioningStrategy
ProvisioningStrategy<WorkerTaskRunner> provisioningStrategy,
HttpClient httpClient
)
{
RemoteTaskRunner remoteTaskRunner = new TestableRemoteTaskRunner(
@ -132,7 +133,7 @@ public class RemoteTaskRunnerTestUtils
),
cf,
new PathChildrenCacheFactory.Builder(),
null,
httpClient,
DSuppliers.of(new AtomicReference<>(DefaultWorkerBehaviorConfig.defaultConfig())),
provisioningStrategy
);