RemoteTaskRunner: Fix NPE in streamTaskReports. (#12006)

* RemoteTaskRunner: Fix NPE in streamTaskReports.

It is possible for a work item to drop out of runningTasks after the
ZkWorker is retrieved. In this case, the current code would throw
an NPE.

* Additional tests and additional fixes.

* Fix import.
This commit is contained in:
Gian Merlino 2022-05-19 14:23:55 -07:00 committed by GitHub
parent 65a1375b67
commit 5f95cc61fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 206 additions and 44 deletions

View File

@ -660,37 +660,50 @@ public class RemoteTaskRunner implements WorkerTaskRunner, TaskLogStreamer
if (zkWorker == null) { if (zkWorker == null) {
// Worker is not running this task, it might be available in deep storage // Worker is not running this task, it might be available in deep storage
return Optional.absent(); return Optional.absent();
} else { }
TaskLocation taskLocation = runningTasks.get(taskId).getLocation();
final URL url = TaskRunnerUtils.makeTaskLocationURL( final RemoteTaskRunnerWorkItem runningWorkItem = runningTasks.get(taskId);
taskLocation,
"/druid/worker/v1/chat/%s/liveReports", if (runningWorkItem == null) {
taskId // Worker very recently exited.
); return Optional.absent();
return Optional.of( }
new ByteSource()
final TaskLocation taskLocation = runningWorkItem.getLocation();
if (TaskLocation.unknown().equals(taskLocation)) {
// No location known for this task. It may have not been assigned one yet.
return Optional.absent();
}
final URL url = TaskRunnerUtils.makeTaskLocationURL(
taskLocation,
"/druid/worker/v1/chat/%s/liveReports",
taskId
);
return Optional.of(
new ByteSource()
{
@Override
public InputStream openStream() throws IOException
{ {
@Override try {
public InputStream openStream() throws IOException return httpClient.go(
{ new Request(HttpMethod.GET, url),
try { new InputStreamResponseHandler()
return httpClient.go( ).get();
new Request(HttpMethod.GET, url), }
new InputStreamResponseHandler() catch (InterruptedException e) {
).get(); throw new RuntimeException(e);
} }
catch (InterruptedException e) { catch (ExecutionException e) {
throw new RuntimeException(e); // Unwrap if possible
} Throwables.propagateIfPossible(e.getCause(), IOException.class);
catch (ExecutionException e) { throw new RuntimeException(e);
// Unwrap if possible
Throwables.propagateIfPossible(e.getCause(), IOException.class);
throw new RuntimeException(e);
}
} }
} }
); }
} );
} }
/** /**

View File

@ -1024,6 +1024,12 @@ public class HttpRemoteTaskRunner implements WorkerTaskRunner, TaskLogStreamer
} else { } else {
// Worker is still running this task // Worker is still running this task
TaskLocation taskLocation = taskRunnerWorkItem.getLocation(); TaskLocation taskLocation = taskRunnerWorkItem.getLocation();
if (TaskLocation.unknown().equals(taskLocation)) {
// No location known for this task. It may have not been assigned a location yet.
return Optional.absent();
}
final URL url = TaskRunnerUtils.makeTaskLocationURL( final URL url = TaskRunnerUtils.makeTaskLocationURL(
taskLocation, taskLocation,
"/druid/worker/v1/chat/%s/liveReports", "/druid/worker/v1/chat/%s/liveReports",

View File

@ -19,12 +19,16 @@
package org.apache.druid.indexing.common.tasklogs; package org.apache.druid.indexing.common.tasklogs;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
import com.google.common.io.Files; import com.google.common.io.Files;
import org.apache.druid.indexing.common.TaskReport;
import org.apache.druid.indexing.common.config.FileTaskLogsConfig; import org.apache.druid.indexing.common.config.FileTaskLogsConfig;
import org.apache.druid.java.util.common.FileUtils; import org.apache.druid.java.util.common.FileUtils;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.tasklogs.TaskLogs; import org.apache.druid.tasklogs.TaskLogs;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Rule; import org.junit.Rule;
@ -69,6 +73,28 @@ public class FileTaskLogsTest
} }
} }
@Test
public void testSimpleReport() throws Exception
{
final ObjectMapper mapper = TestHelper.makeJsonMapper();
final File tmpDir = temporaryFolder.newFolder();
final File logDir = new File(tmpDir, "druid/logs");
final File reportFile = new File(tmpDir, "report.json");
final String taskId = "myTask";
final TestTaskReport testReport = new TestTaskReport(taskId);
final String testReportString = mapper.writeValueAsString(TaskReport.buildTaskReports(testReport));
Files.write(testReportString, reportFile, StandardCharsets.UTF_8);
final TaskLogs taskLogs = new FileTaskLogs(new FileTaskLogsConfig(logDir));
taskLogs.pushTaskReports("foo", reportFile);
Assert.assertEquals(
testReportString,
StringUtils.fromUtf8(ByteStreams.toByteArray(taskLogs.streamTaskReports("foo").get().openStream()))
);
}
@Test @Test
public void testPushTaskLogDirCreationFails() throws Exception public void testPushTaskLogDirCreationFails() throws Exception
{ {
@ -123,4 +149,37 @@ public class FileTaskLogsTest
{ {
return StringUtils.fromUtf8(ByteStreams.toByteArray(taskLogs.streamTaskLog(logFile, offset).get().openStream())); return StringUtils.fromUtf8(ByteStreams.toByteArray(taskLogs.streamTaskLog(logFile, offset).get().openStream()));
} }
private static class TestTaskReport implements TaskReport
{
static final String KEY = "testReport";
static final Map<String, Object> PAYLOAD = ImmutableMap.of("foo", "bar");
private final String taskId;
public TestTaskReport(String taskId)
{
this.taskId = taskId;
}
@Override
@JsonProperty
public String getTaskId()
{
return taskId;
}
@Override
public String getReportKey()
{
return KEY;
}
@Override
@JsonProperty
public Object getPayload()
{
return PAYLOAD;
}
}
} }

View File

@ -73,9 +73,17 @@ public class OverlordBlinkLeadershipTest
public void testOverlordBlinkLeadership() public void testOverlordBlinkLeadership()
{ {
try { try {
RemoteTaskRunner remoteTaskRunner1 = rtrUtils.makeRemoteTaskRunner(remoteTaskRunnerConfig, resourceManagement); RemoteTaskRunner remoteTaskRunner1 = rtrUtils.makeRemoteTaskRunner(
remoteTaskRunnerConfig,
resourceManagement,
null
);
remoteTaskRunner1.stop(); remoteTaskRunner1.stop();
RemoteTaskRunner remoteTaskRunner2 = rtrUtils.makeRemoteTaskRunner(remoteTaskRunnerConfig, resourceManagement); RemoteTaskRunner remoteTaskRunner2 = rtrUtils.makeRemoteTaskRunner(
remoteTaskRunnerConfig,
resourceManagement,
null
);
remoteTaskRunner2.stop(); remoteTaskRunner2.stop();
} }
catch (Exception e) { catch (Exception e) {

View File

@ -69,7 +69,8 @@ public class RemoteTaskRunnerRunPendingTasksConcurrencyTest
{ {
return 2; return 2;
} }
} },
null
); );
int numTasks = 6; int numTasks = 6;

View File

@ -22,14 +22,18 @@ package org.apache.druid.indexing.overlord;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Function; import com.google.common.base.Function;
import com.google.common.base.Joiner; import com.google.common.base.Joiner;
import com.google.common.base.Optional;
import com.google.common.base.Predicate; import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables; import com.google.common.collect.Iterables;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFuture;
import org.apache.curator.framework.CuratorFramework; import org.apache.curator.framework.CuratorFramework;
import org.apache.curator.framework.recipes.cache.PathChildrenCache; import org.apache.curator.framework.recipes.cache.PathChildrenCache;
import org.apache.druid.indexer.TaskLocation;
import org.apache.druid.indexer.TaskState; import org.apache.druid.indexer.TaskState;
import org.apache.druid.indexer.TaskStatus; import org.apache.druid.indexer.TaskStatus;
import org.apache.druid.indexing.common.IndexingServiceCondition; import org.apache.druid.indexing.common.IndexingServiceCondition;
@ -46,6 +50,8 @@ import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.emitter.EmittingLogger; import org.apache.druid.java.util.emitter.EmittingLogger;
import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.java.util.http.client.HttpClient;
import org.apache.druid.java.util.http.client.Request;
import org.apache.druid.server.metrics.NoopServiceEmitter; import org.apache.druid.server.metrics.NoopServiceEmitter;
import org.apache.druid.testing.DeadlockDetectingTimeout; import org.apache.druid.testing.DeadlockDetectingTimeout;
import org.easymock.Capture; import org.easymock.Capture;
@ -61,6 +67,9 @@ import org.junit.rules.TestWatcher;
import org.junit.runner.Description; import org.junit.runner.Description;
import org.mockito.Mockito; import org.mockito.Mockito;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Map; import java.util.Map;
@ -73,13 +82,17 @@ public class RemoteTaskRunnerTest
private static final Logger LOG = new Logger(RemoteTaskRunnerTest.class); private static final Logger LOG = new Logger(RemoteTaskRunnerTest.class);
private static final Joiner JOINER = RemoteTaskRunnerTestUtils.JOINER; private static final Joiner JOINER = RemoteTaskRunnerTestUtils.JOINER;
private static final String WORKER_HOST = "worker"; private static final String WORKER_HOST = "worker";
private static final String ANNOUCEMENTS_PATH = JOINER.join(RemoteTaskRunnerTestUtils.ANNOUNCEMENTS_PATH, WORKER_HOST); private static final String ANNOUCEMENTS_PATH = JOINER.join(
RemoteTaskRunnerTestUtils.ANNOUNCEMENTS_PATH,
WORKER_HOST
);
private static final String STATUS_PATH = JOINER.join(RemoteTaskRunnerTestUtils.STATUS_PATH, WORKER_HOST); private static final String STATUS_PATH = JOINER.join(RemoteTaskRunnerTestUtils.STATUS_PATH, WORKER_HOST);
// higher timeout to reduce flakiness on CI pipeline // higher timeout to reduce flakiness on CI pipeline
private static final Period TIMEOUT_PERIOD = Period.millis(30000); private static final Period TIMEOUT_PERIOD = Period.millis(30000);
private RemoteTaskRunner remoteTaskRunner; private RemoteTaskRunner remoteTaskRunner;
private HttpClient httpClient;
private RemoteTaskRunnerTestUtils rtrTestUtils = new RemoteTaskRunnerTestUtils(); private RemoteTaskRunnerTestUtils rtrTestUtils = new RemoteTaskRunnerTestUtils();
private ObjectMapper jsonMapper; private ObjectMapper jsonMapper;
private CuratorFramework cf; private CuratorFramework cf;
@ -401,7 +414,8 @@ public class RemoteTaskRunnerTest
new TaskResource("first", 1), new TaskResource("first", 1),
"foo", "foo",
TaskStatus.running("first"), TaskStatus.running("first"),
jsonMapper); jsonMapper
);
remoteTaskRunner.run(task1); remoteTaskRunner.run(task1);
Assert.assertTrue(taskAnnounced(task1.getId())); Assert.assertTrue(taskAnnounced(task1.getId()));
mockWorkerRunningTask(task1); mockWorkerRunningTask(task1);
@ -411,7 +425,8 @@ public class RemoteTaskRunnerTest
new TaskResource("task", 2), new TaskResource("task", 2),
"foo", "foo",
TaskStatus.running("task"), TaskStatus.running("task"),
jsonMapper); jsonMapper
);
remoteTaskRunner.run(task); remoteTaskRunner.run(task);
TestRealtimeTask task2 = new TestRealtimeTask( TestRealtimeTask task2 = new TestRealtimeTask(
@ -419,7 +434,8 @@ public class RemoteTaskRunnerTest
new TaskResource("second", 2), new TaskResource("second", 2),
"foo", "foo",
TaskStatus.running("second"), TaskStatus.running("second"),
jsonMapper); jsonMapper
);
remoteTaskRunner.run(task2); remoteTaskRunner.run(task2);
Assert.assertTrue(taskAnnounced(task2.getId())); Assert.assertTrue(taskAnnounced(task2.getId()));
mockWorkerRunningTask(task2); mockWorkerRunningTask(task2);
@ -449,7 +465,8 @@ public class RemoteTaskRunnerTest
new TaskResource("testTask", 2), new TaskResource("testTask", 2),
"foo", "foo",
TaskStatus.success("testTask"), TaskStatus.success("testTask"),
jsonMapper); jsonMapper
);
remoteTaskRunner.run(task1); remoteTaskRunner.run(task1);
Assert.assertTrue(taskAnnounced(task1.getId())); Assert.assertTrue(taskAnnounced(task1.getId()));
mockWorkerRunningTask(task1); mockWorkerRunningTask(task1);
@ -590,7 +607,8 @@ public class RemoteTaskRunnerTest
private void makeRemoteTaskRunner(RemoteTaskRunnerConfig config) private void makeRemoteTaskRunner(RemoteTaskRunnerConfig config)
{ {
remoteTaskRunner = rtrTestUtils.makeRemoteTaskRunner(config); httpClient = EasyMock.createMock(HttpClient.class);
remoteTaskRunner = rtrTestUtils.makeRemoteTaskRunner(config, httpClient);
} }
private void makeWorker() throws Exception private void makeWorker() throws Exception
@ -1022,7 +1040,10 @@ public class RemoteTaskRunnerTest
mockWorkerCompleteFailedTask(task3); mockWorkerCompleteFailedTask(task3);
Assert.assertTrue(taskFuture3.get().isFailure()); Assert.assertTrue(taskFuture3.get().isFailure());
Assert.assertEquals(1, remoteTaskRunner.getBlackListedWorkers().size()); Assert.assertEquals(1, remoteTaskRunner.getBlackListedWorkers().size());
Assert.assertEquals(3, remoteTaskRunner.getBlacklistedTaskSlotCount().get(WorkerConfig.DEFAULT_CATEGORY).longValue()); Assert.assertEquals(
3,
remoteTaskRunner.getBlacklistedTaskSlotCount().get(WorkerConfig.DEFAULT_CATEGORY).longValue()
);
mockWorkerCompleteSuccessfulTask(task2); mockWorkerCompleteSuccessfulTask(task2);
Assert.assertTrue(taskFuture2.get().isSuccess()); Assert.assertTrue(taskFuture2.get().isSuccess());
@ -1046,7 +1067,8 @@ public class RemoteTaskRunnerTest
PathChildrenCache cache = new PathChildrenCache(cf, "/test", true); PathChildrenCache cache = new PathChildrenCache(cf, "/test", true);
testStartWithNoWorker(); testStartWithNoWorker();
cache.getListenable().addListener(remoteTaskRunner.getStatusListener(worker, new ZkWorker(worker, cache, jsonMapper), null)); cache.getListenable()
.addListener(remoteTaskRunner.getStatusListener(worker, new ZkWorker(worker, cache, jsonMapper), null));
cache.start(PathChildrenCache.StartMode.POST_INITIALIZED_EVENT); cache.start(PathChildrenCache.StartMode.POST_INITIALIZED_EVENT);
// Status listener will recieve event with null data // Status listener will recieve event with null data
@ -1062,4 +1084,56 @@ public class RemoteTaskRunnerTest
Assert.assertNull(alertDataMap.get("znode")); Assert.assertNull(alertDataMap.get("znode"));
// Status listener should successfully completes without throwing exception // Status listener should successfully completes without throwing exception
} }
@Test
public void testStreamTaskReportsUnknownTask() throws Exception
{
doSetup();
Assert.assertEquals(Optional.absent(), remoteTaskRunner.streamTaskReports("foo"));
}
@Test
public void testStreamTaskReportsKnownTask() throws Exception
{
doSetup();
final Capture<Request> capturedRequest = Capture.newInstance();
final String reportString = "my report!";
final ByteArrayInputStream reportResponse = new ByteArrayInputStream(StringUtils.toUtf8(reportString));
EasyMock.expect(httpClient.go(EasyMock.capture(capturedRequest), EasyMock.anyObject()))
.andReturn(Futures.immediateFuture(reportResponse));
EasyMock.replay(httpClient);
ListenableFuture<TaskStatus> result = remoteTaskRunner.run(task);
Assert.assertTrue(taskAnnounced(task.getId()));
mockWorkerRunningTask(task);
// Wait for the task to have a known location.
Assert.assertTrue(
TestUtils.conditionValid(
() ->
!remoteTaskRunner.getRunningTasks().isEmpty()
&& !Iterables.getOnlyElement(remoteTaskRunner.getRunningTasks())
.getLocation()
.equals(TaskLocation.unknown())
)
);
// Stream task reports from a running task.
final InputStream in = remoteTaskRunner.streamTaskReports(task.getId()).get().openStream();
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
ByteStreams.copy(in, baos);
Assert.assertEquals(reportString, StringUtils.fromUtf8(baos.toByteArray()));
// Stream task reports from a completed task.
mockWorkerCompleteSuccessfulTask(task);
Assert.assertTrue(workerCompletedTask(result));
Assert.assertEquals(Optional.absent(), remoteTaskRunner.streamTaskReports(task.getId()));
// Verify the HTTP request.
EasyMock.verify(httpClient);
Assert.assertEquals(
"http://dummy:9000/druid/worker/v1/chat/task%20id%20with%20spaces/liveReports",
capturedRequest.getValue().getUrl().toString()
);
}
} }

View File

@ -106,15 +106,16 @@ public class RemoteTaskRunnerTestUtils
testingCluster.stop(); testingCluster.stop();
} }
RemoteTaskRunner makeRemoteTaskRunner(RemoteTaskRunnerConfig config) RemoteTaskRunner makeRemoteTaskRunner(RemoteTaskRunnerConfig config, HttpClient httpClient)
{ {
NoopProvisioningStrategy<WorkerTaskRunner> resourceManagement = new NoopProvisioningStrategy<>(); NoopProvisioningStrategy<WorkerTaskRunner> resourceManagement = new NoopProvisioningStrategy<>();
return makeRemoteTaskRunner(config, resourceManagement); return makeRemoteTaskRunner(config, resourceManagement, httpClient);
} }
public RemoteTaskRunner makeRemoteTaskRunner( public RemoteTaskRunner makeRemoteTaskRunner(
RemoteTaskRunnerConfig config, RemoteTaskRunnerConfig config,
ProvisioningStrategy<WorkerTaskRunner> provisioningStrategy ProvisioningStrategy<WorkerTaskRunner> provisioningStrategy,
HttpClient httpClient
) )
{ {
RemoteTaskRunner remoteTaskRunner = new TestableRemoteTaskRunner( RemoteTaskRunner remoteTaskRunner = new TestableRemoteTaskRunner(
@ -132,7 +133,7 @@ public class RemoteTaskRunnerTestUtils
), ),
cf, cf,
new PathChildrenCacheFactory.Builder(), new PathChildrenCacheFactory.Builder(),
null, httpClient,
DSuppliers.of(new AtomicReference<>(DefaultWorkerBehaviorConfig.defaultConfig())), DSuppliers.of(new AtomicReference<>(DefaultWorkerBehaviorConfig.defaultConfig())),
provisioningStrategy provisioningStrategy
); );