Add code to wait for segments generated to be loaded on historicals (#14322)

Currently, after an MSQ query, the web console is responsible for waiting for the segments to load. It does so by checking whether any segments are still loading into the target datasource, which can cause issues: the wait can hang on segments that will never be loaded due to load rules, and it can also end up waiting on segments from unrelated ingestions.

This PR shifts this responsibility to the controller, which already has the list of segments it created.
Adarsh Sanjeev 2023-09-06 10:35:57 +05:30 committed by GitHub
parent 706b57c0b2
commit 959148ad37
15 changed files with 1070 additions and 74 deletions

View File

@ -288,7 +288,19 @@ The response shows an example report for a query.
"startTime": "2022-09-14T22:12:09.266Z",
"durationMs": 28227,
"pendingTasks": 0,
"runningTasks": 2
"runningTasks": 2,
"segmentLoadStatus": {
"state": "SUCCESS",
"dataSource": "kttm_simple",
"startTime": "2022-09-14T23:12:09.266Z",
"duration": 15,
"totalSegments": 1,
"usedSegments": 1,
"precachedSegments": 0,
"onDemandSegments": 0,
"pendingSegments": 0,
"unknownSegments": 0
}
},
"stages": [
{
@ -593,6 +605,16 @@ The following table describes the response fields when you retrieve a report for
| `multiStageQuery.payload.status.durationMs` | Milliseconds elapsed after the query has started running. -1 denotes that the query hasn't started running yet. |
| `multiStageQuery.payload.status.pendingTasks` | Number of tasks that are not fully started. -1 denotes that the number is currently unknown. |
| `multiStageQuery.payload.status.runningTasks` | Number of currently running tasks. Should be at least 1 since the controller is included. |
| `multiStageQuery.payload.status.segmentLoadWaiterStatus` | Segment load status container. Only present after the segments have been published. |
| `multiStageQuery.payload.status.segmentLoadWaiterStatus.state` | Either INIT, WAITING, SUCCESS, FAILED, or TIMED_OUT. |
| `multiStageQuery.payload.status.segmentLoadWaiterStatus.startTime` | The time at which the controller began waiting for the segments to finish loading. |
| `multiStageQuery.payload.status.segmentLoadWaiterStatus.duration` | The duration in milliseconds that the controller has been waiting for the segments to load. |
| `multiStageQuery.payload.status.segmentLoadWaiterStatus.totalSegments` | The total number of segments generated by the job. This includes tombstone segments (if any). |
| `multiStageQuery.payload.status.segmentLoadWaiterStatus.usedSegments` | The number of segments which are marked as used based on the load rules. Unused segments can be cleaned up at any time. |
| `multiStageQuery.payload.status.segmentLoadWaiterStatus.precachedSegments` | The number of segments which are marked as precached and served by historicals, as per the load rules. |
| `multiStageQuery.payload.status.segmentLoadWaiterStatus.onDemandSegments` | The number of segments which are not loaded on any historical, as per the load rules. |
| `multiStageQuery.payload.status.segmentLoadWaiterStatus.pendingSegments` | The number of segments remaining to be loaded. |
| `multiStageQuery.payload.status.segmentLoadWaiterStatus.unknownSegments` | The number of segments whose status is unknown. |
| `multiStageQuery.payload.status.errorReport` | Error object. Only present if there was an error. |
| `multiStageQuery.payload.status.errorReport.taskId` | The task that reported the error, if known. May be a controller task or a worker task. |
| `multiStageQuery.payload.status.errorReport.host` | The hostname and port of the task that reported the error, if known. |

View File

@ -44,6 +44,7 @@ import org.apache.druid.data.input.StringTuple;
import org.apache.druid.data.input.impl.DimensionSchema;
import org.apache.druid.data.input.impl.DimensionsSpec;
import org.apache.druid.data.input.impl.TimestampSpec;
import org.apache.druid.discovery.BrokerClient;
import org.apache.druid.frame.allocation.ArenaMemoryAllocator;
import org.apache.druid.frame.channel.FrameChannelSequence;
import org.apache.druid.frame.key.ClusterBy;
@ -63,6 +64,7 @@ import org.apache.druid.indexing.common.TaskLock;
import org.apache.druid.indexing.common.TaskLockType;
import org.apache.druid.indexing.common.TaskReport;
import org.apache.druid.indexing.common.actions.LockListAction;
import org.apache.druid.indexing.common.actions.LockReleaseAction;
import org.apache.druid.indexing.common.actions.MarkSegmentsAsUnusedAction;
import org.apache.druid.indexing.common.actions.SegmentAllocateAction;
import org.apache.druid.indexing.common.actions.SegmentTransactionalInsertAction;
@ -292,6 +294,7 @@ public class ControllerImpl implements Controller
private WorkerMemoryParameters workerMemoryParameters;
private boolean isDurableStorageEnabled;
private boolean isFaultToleranceEnabled;
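// Created when new segments are published. Used after a successful run to wait for the segments to load, and
// consulted when building status reports.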
private volatile SegmentLoadWaiter segmentLoadWaiter;
public ControllerImpl(
final MSQControllerTask task,
@ -437,6 +440,45 @@ public class ControllerImpl implements Controller
}
}
if (queryKernel != null && queryKernel.isSuccess()) {
// If successful, encourage the tasks to exit successfully.
postFinishToAllTasks();
workerTaskLauncher.stop(false);
} else {
// If not successful, cancel running tasks.
if (workerTaskLauncher != null) {
workerTaskLauncher.stop(true);
}
}
// Wait for worker tasks to exit. Ignore their return status. At this point, we've done everything we need to do,
// so we don't care about the task exit status.
if (workerTaskRunnerFuture != null) {
try {
workerTaskRunnerFuture.get();
}
catch (Exception ignored) {
// Suppress.
}
}
try {
releaseTaskLocks();
cleanUpDurableStorageIfNeeded();
if (queryKernel != null && queryKernel.isSuccess()) {
if (segmentLoadWaiter != null) {
// If successful and there are segments created, segmentLoadWaiter should wait for them to become available.
segmentLoadWaiter.waitForSegmentsToLoad();
}
}
}
catch (Exception e) {
log.warn(e, "Exception thrown during cleanup. Ignoring it and writing task report.");
}
try {
// Write report even if something went wrong.
final MSQStagesReport stagesReport;
@ -488,7 +530,8 @@ public class ControllerImpl implements Controller
workerWarnings,
queryStartTime,
new Interval(queryStartTime, DateTimes.nowUtc()).toDurationMillis(),
workerTaskLauncher
workerTaskLauncher,
segmentLoadWaiter
),
stagesReport,
countersSnapshot,
@ -504,30 +547,6 @@ public class ControllerImpl implements Controller
log.warn(e, "Error encountered while writing task report. Skipping.");
}
if (queryKernel != null && queryKernel.isSuccess()) {
// If successful, encourage the tasks to exit successfully.
postFinishToAllTasks();
workerTaskLauncher.stop(false);
} else {
// If not successful, cancel running tasks.
if (workerTaskLauncher != null) {
workerTaskLauncher.stop(true);
}
}
// Wait for worker tasks to exit. Ignore their return status. At this point, we've done everything we need to do,
// so we don't care about the task exit status.
if (workerTaskRunnerFuture != null) {
try {
workerTaskRunnerFuture.get();
}
catch (Exception ignored) {
// Suppress.
}
}
cleanUpDurableStorageIfNeeded();
if (taskStateForReport == TaskState.SUCCESS) {
return TaskStatus.success(id());
} else {
@ -536,6 +555,23 @@ public class ControllerImpl implements Controller
}
}
/**
* Releases the locks obtained by the task.
*/
private void releaseTaskLocks() throws IOException
{
final List<TaskLock> locks;
try {
locks = context.taskActionClient().submit(new LockListAction());
for (final TaskLock lock : locks) {
context.taskActionClient().submit(new LockReleaseAction(lock.getInterval()));
}
}
catch (IOException e) {
throw new IOException("Failed to release locks", e);
}
}
/**
* Adds some logic to {@link #kernelManipulationQueue}, where it will, in due time, be executed by the main
* controller loop in {@link RunQueryUntilDone#run()}.
@ -875,7 +911,8 @@ public class ControllerImpl implements Controller
workerWarnings,
queryStartTime,
queryStartTime == null ? -1L : new Interval(queryStartTime, DateTimes.nowUtc()).toDurationMillis(),
workerTaskLauncher
workerTaskLauncher,
segmentLoadWaiter
),
makeStageReport(
queryDef,
@ -1316,17 +1353,36 @@ public class ControllerImpl implements Controller
if (segmentsWithTombstones.isEmpty()) {
// Nothing to publish, only drop. We already validated that the intervalsToDrop do not have any
// partially-overlapping segments, so it's safe to drop them as intervals instead of as specific segments.
// This should not need a segment load wait as segments are marked as unused immediately.
for (final Interval interval : intervalsToDrop) {
context.taskActionClient()
.submit(new MarkSegmentsAsUnusedAction(task.getDataSource(), interval));
}
} else {
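// Replace mode: collect the versions of the new segments (including tombstones) so that the waiter can poll
// their load status after publishing.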
Set<String> versionsToAwait = segmentsWithTombstones.stream().map(DataSegment::getVersion).collect(Collectors.toSet());
segmentLoadWaiter = new SegmentLoadWaiter(
context.injector().getInstance(BrokerClient.class),
context.jsonMapper(),
task.getDataSource(),
versionsToAwait,
segmentsWithTombstones.size(),
true
);
performSegmentPublish(
context.taskActionClient(),
SegmentTransactionalInsertAction.overwriteAction(null, segmentsWithTombstones)
);
}
} else if (!segments.isEmpty()) {
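// Collect the versions of the appended segments so that the waiter can poll their load status after publishing.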
Set<String> versionsToAwait = segments.stream().map(DataSegment::getVersion).collect(Collectors.toSet());
segmentLoadWaiter = new SegmentLoadWaiter(
context.injector().getInstance(BrokerClient.class),
context.jsonMapper(),
task.getDataSource(),
versionsToAwait,
segments.size(),
true
);
// Append mode.
performSegmentPublish(
context.taskActionClient(),
@ -2072,7 +2128,8 @@ public class ControllerImpl implements Controller
final Queue<MSQErrorReport> errorReports,
@Nullable final DateTime queryStartTime,
final long queryDuration,
MSQWorkerTaskLauncher taskLauncher
MSQWorkerTaskLauncher taskLauncher,
final SegmentLoadWaiter segmentLoadWaiter
)
{
int pendingTasks = -1;
@ -2083,6 +2140,9 @@ public class ControllerImpl implements Controller
pendingTasks = workerTaskCount.getPendingWorkerCount();
runningTasks = workerTaskCount.getRunningWorkerCount() + 1; // To account for controller.
}
SegmentLoadWaiter.SegmentLoadWaiterStatus status = segmentLoadWaiter == null ? null : segmentLoadWaiter.status();
return new MSQStatusReport(
taskState,
errorReport,
@ -2090,7 +2150,8 @@ public class ControllerImpl implements Controller
queryStartTime,
queryDuration,
pendingTasks,
runningTasks
runningTasks,
status
);
}
@ -2259,6 +2320,7 @@ public class ControllerImpl implements Controller
throwKernelExceptionIfNotUnknown();
}
updateLiveReportMaps();
cleanUpEffectivelyFinishedStages();
return Pair.of(queryKernel, workerTaskLauncherFuture);
}

View File

@ -0,0 +1,422 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.exec;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.druid.discovery.BrokerClient;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.http.client.Request;
import org.apache.druid.sql.http.ResultFormat;
import org.apache.druid.sql.http.SqlQuery;
import org.jboss.netty.handler.codec.http.HttpMethod;
import org.joda.time.DateTime;
import org.joda.time.Interval;
import javax.annotation.Nullable;
import javax.ws.rs.core.MediaType;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
/**
* Class that periodically queries the broker's sys.segments table to check whether all the segments generated by the
* ingestion have been loaded, and blocks until loading is complete. It accounts for, and does not wait on, segments
* that will never be loaded due to load rules. Should only be used if the query generates new segments or tombstones.
* <br>
* If an exception is thrown during operation, this will simply log the exception and exit without failing the task,
* since the segments have already been published successfully, and should be loaded eventually.
* <br>
* If the segments are not loaded within {@link #TIMEOUT_DURATION_MILLIS} milliseconds, this logs a warning and exits
* for the same reason.
*/
public class SegmentLoadWaiter
{
private static final Logger log = new Logger(SegmentLoadWaiter.class);
private static final long SLEEP_DURATION_MILLIS = TimeUnit.SECONDS.toMillis(5);
private static final long TIMEOUT_DURATION_MILLIS = TimeUnit.MINUTES.toMillis(10);
/**
* The query sent to the broker. This query uses replication_factor to determine how many copies of a segment have to
* be loaded as per the load rules.
* - If a segment is not used, the broker will not have any information about it; hence, COUNT(*) returns the count of used segments only.
* - If replication_factor is more than 0, the segment will be loaded on historicals and needs to be waited for.
* - If replication_factor is 0, that means that the segment will never be loaded on a historical and does not need to
* be waited for.
* - If replication_factor is -1, the replication factor is not known currently and will become known after a load rule
* evaluation.
* <br>
* See https://github.com/apache/druid/pull/14403 for more details about replication_factor
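* <br>
* For example, with a hypothetical datasource "foo" and segment version "v1" (illustrative values only), the
* template below renders with the WHERE clause:
* WHERE datasource = 'foo' AND is_overshadowed = 0 AND version = 'v1'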
*/
private static final String LOAD_QUERY = "SELECT COUNT(*) AS usedSegments,\n"
+ "COUNT(*) FILTER (WHERE is_published = 1 AND replication_factor > 0) AS precachedSegments,\n"
+ "COUNT(*) FILTER (WHERE is_published = 1 AND replication_factor = 0) AS onDemandSegments,\n"
+ "COUNT(*) FILTER (WHERE is_available = 0 AND is_published = 1 AND replication_factor != 0) AS pendingSegments,\n"
+ "COUNT(*) FILTER (WHERE replication_factor = -1) AS unknownSegments\n"
+ "FROM sys.segments\n"
+ "WHERE datasource = '%s' AND is_overshadowed = 0 AND version = '%s'";
private final BrokerClient brokerClient;
private final ObjectMapper objectMapper;
// Map from version to its latest load status.
private final Map<String, VersionLoadStatus> versionToLoadStatusMap;
private final String datasource;
private final Set<String> versionsToAwait;
private final int totalSegmentsGenerated;
private final boolean doWait;
private final AtomicReference<SegmentLoadWaiterStatus> status;
public SegmentLoadWaiter(
BrokerClient brokerClient,
ObjectMapper objectMapper,
String datasource,
Set<String> versionsToAwait,
int totalSegmentsGenerated,
boolean doWait
)
{
this.brokerClient = brokerClient;
this.objectMapper = objectMapper;
this.datasource = datasource;
this.versionsToAwait = new TreeSet<>(versionsToAwait);
this.versionToLoadStatusMap = new HashMap<>();
this.totalSegmentsGenerated = totalSegmentsGenerated;
this.status = new AtomicReference<>(new SegmentLoadWaiterStatus(State.INIT, null, 0, totalSegmentsGenerated, 0, 0, 0, 0, totalSegmentsGenerated));
this.doWait = doWait;
}
/**
* Uses the broker client to check if all segments created by the ingestion have been loaded, and updates the {@link #status}
* periodically.
* <br>
* If an exception is thrown during operation, this will log the exception and return without failing the task,
* since the segments have already been published successfully, and should be loaded eventually.
* <br>
* Only expected to be called from the main controller thread.
*/
public void waitForSegmentsToLoad()
{
DateTime startTime = DateTimes.nowUtc();
boolean hasAnySegmentBeenLoaded = false;
try {
while (!versionsToAwait.isEmpty()) {
// Check the timeout and exit if exceeded.
long runningMillis = new Interval(startTime, DateTimes.nowUtc()).toDurationMillis();
if (runningMillis > TIMEOUT_DURATION_MILLIS) {
log.warn("Runtime [%s] exceeded timeout [%s] while waiting for segments to load. Exiting.", runningMillis, TIMEOUT_DURATION_MILLIS);
updateStatus(State.TIMED_OUT, startTime);
return;
}
Iterator<String> iterator = versionsToAwait.iterator();
// Query the broker for all pending versions
while (iterator.hasNext()) {
String version = iterator.next();
// Fetch the load status for this version from the broker
VersionLoadStatus loadStatus = fetchLoadStatusForVersion(version);
versionToLoadStatusMap.put(version, loadStatus);
hasAnySegmentBeenLoaded = hasAnySegmentBeenLoaded || loadStatus.getUsedSegments() > 0;
// If loading is done for this version, remove it from future loops.
if (hasAnySegmentBeenLoaded && loadStatus.isLoadingComplete()) {
iterator.remove();
}
}
if (!versionsToAwait.isEmpty()) {
// Update the status.
updateStatus(State.WAITING, startTime);
// Sleep for a while before retrying.
waitIfNeeded(SLEEP_DURATION_MILLIS);
}
}
}
catch (Exception e) {
log.warn(e, "Exception occurred while waiting for segments to load. Exiting.");
// Update the status and return.
updateStatus(State.FAILED, startTime);
return;
}
// Update the status.
updateStatus(State.SUCCESS, startTime);
}
private void waitIfNeeded(long waitTimeMillis) throws Exception
{
if (doWait) {
Thread.sleep(waitTimeMillis);
}
}
/**
* Updates the {@link #status} with the latest details based on {@link #versionToLoadStatusMap}
*/
private void updateStatus(State state, DateTime startTime)
{
int pendingSegmentCount = 0, usedSegmentsCount = 0, precachedSegmentCount = 0, onDemandSegmentCount = 0, unknownSegmentCount = 0;
for (Map.Entry<String, VersionLoadStatus> entry : versionToLoadStatusMap.entrySet()) {
usedSegmentsCount += entry.getValue().getUsedSegments();
precachedSegmentCount += entry.getValue().getPrecachedSegments();
onDemandSegmentCount += entry.getValue().getOnDemandSegments();
unknownSegmentCount += entry.getValue().getUnknownSegments();
pendingSegmentCount += entry.getValue().getPendingSegments();
}
long runningMillis = new Interval(startTime, DateTimes.nowUtc()).toDurationMillis();
status.set(
new SegmentLoadWaiterStatus(
state,
startTime,
runningMillis,
totalSegmentsGenerated,
usedSegmentsCount,
precachedSegmentCount,
onDemandSegmentCount,
pendingSegmentCount,
unknownSegmentCount
)
);
}
/**
* Uses {@link #brokerClient} to fetch latest load status for a given version. Converts the response into a
* {@link VersionLoadStatus} and returns it.
*/
private VersionLoadStatus fetchLoadStatusForVersion(String version) throws Exception
{
Request request = brokerClient.makeRequest(HttpMethod.POST, "/druid/v2/sql/");
SqlQuery sqlQuery = new SqlQuery(StringUtils.format(LOAD_QUERY, datasource, version),
ResultFormat.OBJECTLINES,
false, false, false, null, null);
request.setContent(MediaType.APPLICATION_JSON, objectMapper.writeValueAsBytes(sqlQuery));
String response = brokerClient.sendQuery(request);
if (response.trim().isEmpty()) {
// If no segments are returned for a version, all segments have been dropped by a drop rule.
return new VersionLoadStatus(0, 0, 0, 0, 0);
} else {
return objectMapper.readValue(response, VersionLoadStatus.class);
}
}
/**
* Returns the current status of the load.
*/
public SegmentLoadWaiterStatus status()
{
return status.get();
}
public static class SegmentLoadWaiterStatus
{
private final State state;
private final DateTime startTime;
private final long duration;
private final int totalSegments;
private final int usedSegments;
private final int precachedSegments;
private final int onDemandSegments;
private final int pendingSegments;
private final int unknownSegments;
@JsonCreator
public SegmentLoadWaiterStatus(
@JsonProperty("state") SegmentLoadWaiter.State state,
@JsonProperty("startTime") @Nullable DateTime startTime,
@JsonProperty("duration") long duration,
@JsonProperty("totalSegments") int totalSegments,
@JsonProperty("usedSegments") int usedSegments,
@JsonProperty("precachedSegments") int precachedSegments,
@JsonProperty("onDemandSegments") int onDemandSegments,
@JsonProperty("pendingSegments") int pendingSegments,
@JsonProperty("unknownSegments") int unknownSegments
)
{
this.state = state;
this.startTime = startTime;
this.duration = duration;
this.totalSegments = totalSegments;
this.usedSegments = usedSegments;
this.precachedSegments = precachedSegments;
this.onDemandSegments = onDemandSegments;
this.pendingSegments = pendingSegments;
this.unknownSegments = unknownSegments;
}
@JsonProperty
public SegmentLoadWaiter.State getState()
{
return state;
}
@Nullable
@JsonProperty
@JsonInclude(JsonInclude.Include.NON_NULL)
public DateTime getStartTime()
{
return startTime;
}
@JsonProperty
public long getDuration()
{
return duration;
}
@JsonProperty
public long getTotalSegments()
{
return totalSegments;
}
@JsonProperty
public int getUsedSegments()
{
return usedSegments;
}
@JsonProperty
public int getPrecachedSegments()
{
return precachedSegments;
}
@JsonProperty
public int getOnDemandSegments()
{
return onDemandSegments;
}
@JsonProperty
public int getPendingSegments()
{
return pendingSegments;
}
@JsonProperty
public int getUnknownSegments()
{
return unknownSegments;
}
}
public enum State
{
/**
* Initial state after being initialized with the segment versions and before {@link #waitForSegmentsToLoad} has been called.
*/
INIT,
/**
* Not all the segments that need to be loaded have been loaded yet. The load status is periodically queried from
* the broker.
*/
WAITING,
/**
* All segments which need to be loaded have been loaded, and the SegmentLoadWaiter exited successfully.
*/
SUCCESS,
/**
* An exception occurred while checking load status. The SegmentLoadWaiter exited without failing the task.
*/
FAILED,
/**
* The time spent waiting for segments to load exceeded {@link #TIMEOUT_DURATION_MILLIS}.
* The SegmentLoadWaiter exited without failing the task.
*/
TIMED_OUT
}
public static class VersionLoadStatus
{
private final int usedSegments;
private final int precachedSegments;
private final int onDemandSegments;
private final int pendingSegments;
private final int unknownSegments;
@JsonCreator
public VersionLoadStatus(
@JsonProperty("usedSegments") int usedSegments,
@JsonProperty("precachedSegments") int precachedSegments,
@JsonProperty("onDemandSegments") int onDemandSegments,
@JsonProperty("pendingSegments") int pendingSegments,
@JsonProperty("unknownSegments") int unknownSegments
)
{
this.usedSegments = usedSegments;
this.precachedSegments = precachedSegments;
this.onDemandSegments = onDemandSegments;
this.pendingSegments = pendingSegments;
this.unknownSegments = unknownSegments;
}
@JsonProperty
public int getUsedSegments()
{
return usedSegments;
}
@JsonProperty
public int getPrecachedSegments()
{
return precachedSegments;
}
@JsonProperty
public int getOnDemandSegments()
{
return onDemandSegments;
}
@JsonProperty
public int getPendingSegments()
{
return pendingSegments;
}
@JsonProperty
public int getUnknownSegments()
{
return unknownSegments;
}
@JsonIgnore
public boolean isLoadingComplete()
{
return pendingSegments == 0 && (usedSegments == precachedSegments + onDemandSegments);
}
}
}

View File

@ -182,7 +182,19 @@ public class MSQWorkerTaskLauncher
*/
public void stop(final boolean interrupt)
{
if (state.compareAndSet(State.STARTED, State.STOPPED)) {
if (state.compareAndSet(State.NEW, State.STOPPED)) {
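// The launcher was stopped before it ever started: mark it stopped and complete the stop future immediately
// so that callers waiting on stop() do not block.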
state.set(State.STOPPED);
if (interrupt) {
cancelTasksOnStop.set(true);
}
synchronized (taskIds) {
// Wake up sleeping mainLoop.
taskIds.notifyAll();
}
exec.shutdown();
stopFuture.set(null);
} else if (state.compareAndSet(State.STARTED, State.STOPPED)) {
if (interrupt) {
cancelTasksOnStop.set(true);
}
@ -466,9 +478,13 @@ public class MSQWorkerTaskLauncher
public WorkerCount getWorkerTaskCount()
{
synchronized (taskIds) {
int runningTasks = fullyStartedTasks.size();
int pendingTasks = desiredTaskCount - runningTasks;
return new WorkerCount(runningTasks, pendingTasks);
if (stopFuture.isDone()) {
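// The launcher has been stopped; report zero running and pending workers so that status reports do not show
// phantom pending tasks after shutdown.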
return new WorkerCount(0, 0);
} else {
int runningTasks = fullyStartedTasks.size();
int pendingTasks = desiredTaskCount - runningTasks;
return new WorkerCount(runningTasks, pendingTasks);
}
}
}

View File

@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import org.apache.druid.indexer.TaskState;
import org.apache.druid.msq.exec.SegmentLoadWaiter;
import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.joda.time.DateTime;
@ -50,6 +51,9 @@ public class MSQStatusReport
private final int runningTasks;
@Nullable
private final SegmentLoadWaiter.SegmentLoadWaiterStatus segmentLoadWaiterStatus;
@JsonCreator
public MSQStatusReport(
@JsonProperty("status") TaskState status,
@ -58,7 +62,8 @@ public class MSQStatusReport
@JsonProperty("startTime") @Nullable DateTime startTime,
@JsonProperty("durationMs") long durationMs,
@JsonProperty("pendingTasks") int pendingTasks,
@JsonProperty("runningTasks") int runningTasks
@JsonProperty("runningTasks") int runningTasks,
@JsonProperty("segmentLoadWaiterStatus") @Nullable SegmentLoadWaiter.SegmentLoadWaiterStatus segmentLoadWaiterStatus
)
{
this.status = Preconditions.checkNotNull(status, "status");
@ -68,6 +73,7 @@ public class MSQStatusReport
this.durationMs = durationMs;
this.pendingTasks = pendingTasks;
this.runningTasks = runningTasks;
this.segmentLoadWaiterStatus = segmentLoadWaiterStatus;
}
@JsonProperty
@ -117,6 +123,14 @@ public class MSQStatusReport
return durationMs;
}
@Nullable
@JsonProperty
@JsonInclude(JsonInclude.Include.NON_NULL)
public SegmentLoadWaiter.SegmentLoadWaiterStatus getSegmentLoadWaiterStatus()
{
return segmentLoadWaiterStatus;
}
@Override
public boolean equals(Object o)
{

View File

@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.exec;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.discovery.BrokerClient;
import org.apache.druid.java.util.http.client.Request;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
public class SegmentLoadWaiterTest
{
private static final String TEST_DATASOURCE = "testDatasource";
private SegmentLoadWaiter segmentLoadWaiter;
private BrokerClient brokerClient;
/**
* Single version created, loaded after five polling attempts.
*/
@Test
public void testSingleVersionWaitsForLoadCorrectly() throws Exception
{
brokerClient = mock(BrokerClient.class);
doReturn(mock(Request.class)).when(brokerClient).makeRequest(any(), anyString());
doAnswer(new Answer<String>()
{
int timesInvoked = 0;
@Override
public String answer(InvocationOnMock invocation) throws Throwable
{
timesInvoked += 1;
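// Simulate gradual loading: each poll reports one more precached segment, completing on the fifth poll.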
SegmentLoadWaiter.VersionLoadStatus loadStatus = new SegmentLoadWaiter.VersionLoadStatus(5, timesInvoked, 0, 5 - timesInvoked, 0);
return new ObjectMapper().writeValueAsString(loadStatus);
}
}).when(brokerClient).sendQuery(any());
segmentLoadWaiter = new SegmentLoadWaiter(brokerClient, new ObjectMapper(), TEST_DATASOURCE, ImmutableSet.of("version1"), 5, false);
segmentLoadWaiter.waitForSegmentsToLoad();
verify(brokerClient, times(5)).sendQuery(any());
}
/**
* Two versions created, loaded over several shared polling rounds.
*/
@Test
public void testMultipleVersionWaitsForLoadCorrectly() throws Exception
{
brokerClient = mock(BrokerClient.class);
doReturn(mock(Request.class)).when(brokerClient).makeRequest(any(), anyString());
doAnswer(new Answer<String>()
{
int timesInvoked = 0;
@Override
public String answer(InvocationOnMock invocation) throws Throwable
{
timesInvoked += 1;
// Report one more precached segment per poll, capping at the total of 5 so that each version eventually
// reports a complete load regardless of how many polls have already happened.
SegmentLoadWaiter.VersionLoadStatus loadStatus = new SegmentLoadWaiter.VersionLoadStatus(5, Math.min(timesInvoked, 5), 0, Math.max(5 - timesInvoked, 0), 0);
return new ObjectMapper().writeValueAsString(loadStatus);
}
}).when(brokerClient).sendQuery(any());
segmentLoadWaiter = new SegmentLoadWaiter(brokerClient, new ObjectMapper(), TEST_DATASOURCE, ImmutableSet.of("version1", "version2"), 5, false);
segmentLoadWaiter.waitForSegmentsToLoad();
// Three polling rounds over two versions: version1 completes on the 5th call overall, version2 on the 6th.
verify(brokerClient, times(6)).sendQuery(any());
}
}

View File

@ -30,10 +30,12 @@ import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.indexer.TaskState;
import org.apache.druid.indexing.common.SingleFileTaskReportFileWriter;
import org.apache.druid.indexing.common.TaskReport;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.guava.Yielder;
import org.apache.druid.java.util.common.guava.Yielders;
import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.exec.SegmentLoadWaiter;
import org.apache.druid.msq.guice.MSQIndexingModule;
import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.apache.druid.msq.indexing.error.TooManyColumnsFault;
@ -90,10 +92,22 @@ public class MSQTaskReportTest
new Object[]{"bar"}
);
SegmentLoadWaiter.SegmentLoadWaiterStatus status = new SegmentLoadWaiter.SegmentLoadWaiterStatus(
SegmentLoadWaiter.State.WAITING,
DateTimes.nowUtc(),
200L,
100,
80,
30,
50,
10,
0
);
final MSQTaskReport report = new MSQTaskReport(
TASK_ID,
new MSQTaskReportPayload(
new MSQStatusReport(TaskState.SUCCESS, null, new ArrayDeque<>(), null, 0, 1, 2),
new MSQStatusReport(TaskState.SUCCESS, null, new ArrayDeque<>(), null, 0, 1, 2, status),
MSQStagesReport.create(
QUERY_DEFINITION,
ImmutableMap.of(),
@ -142,11 +156,23 @@ public class MSQTaskReportTest
@Test
public void testSerdeErrorReport() throws Exception
{
SegmentLoadWaiter.SegmentLoadWaiterStatus status = new SegmentLoadWaiter.SegmentLoadWaiterStatus(
SegmentLoadWaiter.State.FAILED,
DateTimes.nowUtc(),
200L,
100,
80,
30,
50,
10,
0
);
final MSQErrorReport errorReport = MSQErrorReport.fromFault(TASK_ID, HOST, 0, new TooManyColumnsFault(10, 5));
final MSQTaskReport report = new MSQTaskReport(
TASK_ID,
new MSQTaskReportPayload(
new MSQStatusReport(TaskState.FAILED, errorReport, new ArrayDeque<>(), null, 0, 1, 2),
new MSQStatusReport(TaskState.FAILED, errorReport, new ArrayDeque<>(), null, 0, 1, 2, status),
MSQStagesReport.create(
QUERY_DEFINITION,
ImmutableMap.of(),
@ -179,10 +205,22 @@ public class MSQTaskReportTest
@Test
public void testWriteTaskReport() throws Exception
{
SegmentLoadWaiter.SegmentLoadWaiterStatus status = new SegmentLoadWaiter.SegmentLoadWaiterStatus(
SegmentLoadWaiter.State.SUCCESS,
DateTimes.nowUtc(),
200L,
100,
80,
30,
50,
10,
0
);
final MSQTaskReport report = new MSQTaskReport(
TASK_ID,
new MSQTaskReportPayload(
new MSQStatusReport(TaskState.SUCCESS, null, new ArrayDeque<>(), null, 0, 1, 2),
new MSQStatusReport(TaskState.SUCCESS, null, new ArrayDeque<>(), null, 0, 1, 2, status),
MSQStagesReport.create(
QUERY_DEFINITION,
ImmutableMap.of(),

View File

@ -240,7 +240,8 @@ public class SqlStatementResourceTest extends MSQTestBase
null,
0,
1,
2
2,
null
),
MSQStagesReport.create(
MSQTaskReportTest.QUERY_DEFINITION,
@ -305,7 +306,8 @@ public class SqlStatementResourceTest extends MSQTestBase
null,
0,
1,
2
2,
null
),
MSQStagesReport.create(
MSQTaskReportTest.QUERY_DEFINITION,

View File

@ -42,6 +42,7 @@ import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.impl.DimensionsSpec;
import org.apache.druid.data.input.impl.LongDimensionSchema;
import org.apache.druid.data.input.impl.StringDimensionSchema;
import org.apache.druid.discovery.BrokerClient;
import org.apache.druid.discovery.NodeRole;
import org.apache.druid.frame.channel.FrameChannelSequence;
import org.apache.druid.frame.testutil.FrameTestUtil;
@ -73,6 +74,7 @@ import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.java.util.common.guava.Yielder;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.http.client.Request;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.metadata.input.InputSourceModule;
import org.apache.druid.msq.counters.CounterNames;
@ -210,6 +212,10 @@ import static org.apache.druid.sql.calcite.util.CalciteTests.DATASOURCE1;
import static org.apache.druid.sql.calcite.util.CalciteTests.DATASOURCE2;
import static org.apache.druid.sql.calcite.util.TestDataBuilder.ROWS1;
import static org.apache.druid.sql.calcite.util.TestDataBuilder.ROWS2;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
/**
* Base test runner for running MSQ unit tests. It sets up multi stage query execution environment
@ -352,7 +358,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
// which depends on the object mapper that the injector will provide, once it
// is built, but has not yet been build while we build the SQL engine.
@Before
public void setUp2()
public void setUp2() throws Exception
{
groupByBuffers = TestGroupByBuffers.createDefault();
@ -380,6 +386,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
segmentManager = new MSQTestSegmentManager(segmentCacheManager, indexIO);
BrokerClient brokerClient = mock(BrokerClient.class);
List<Module> modules = ImmutableList.of(
binder -> {
DruidProcessingConfig druidProcessingConfig = new DruidProcessingConfig()
@ -467,7 +474,8 @@ public class MSQTestBase extends BaseCalciteQueryTest
),
new MSQExternalDataSourceModule(),
new LookylooModule(),
new SegmentWranglerModule()
new SegmentWranglerModule(),
binder -> binder.bind(BrokerClient.class).toInstance(brokerClient)
);
// adding node role injection to the modules, since CliPeon would also do that through run method
Injector injector = new CoreInjectorBuilder(new StartupInjectorBuilder().build(), ImmutableSet.of(NodeRole.PEON))
@ -477,6 +485,8 @@ public class MSQTestBase extends BaseCalciteQueryTest
objectMapper = setupObjectMapper(injector);
objectMapper.registerModules(sqlModule.getJacksonModules());
doReturn(mock(Request.class)).when(brokerClient).makeRequest(any(), anyString());
testTaskActionClient = Mockito.spy(new MSQTestTaskActionClient(objectMapper));
indexingServiceClient = new MSQTestOverlordServiceClient(
objectMapper,

View File

@ -0,0 +1,125 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.discovery;
import com.google.inject.Inject;
import org.apache.druid.error.DruidException;
import org.apache.druid.guice.annotations.EscalatedGlobal;
import org.apache.druid.java.util.common.IOE;
import org.apache.druid.java.util.common.RetryUtils;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.http.client.HttpClient;
import org.apache.druid.java.util.http.client.Request;
import org.apache.druid.java.util.http.client.response.StringFullResponseHandler;
import org.apache.druid.java.util.http.client.response.StringFullResponseHolder;
import org.jboss.netty.channel.ChannelException;
import org.jboss.netty.handler.codec.http.HttpMethod;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.ExecutionException;
/**
* This class facilitates interaction with the broker.
*/
public class BrokerClient
{
private static final int MAX_RETRIES = 5;
private final HttpClient brokerHttpClient;
private final DruidNodeDiscovery druidNodeDiscovery;
@Inject
public BrokerClient(
@EscalatedGlobal HttpClient brokerHttpClient,
DruidNodeDiscoveryProvider druidNodeDiscoveryProvider
)
{
this.brokerHttpClient = brokerHttpClient;
this.druidNodeDiscovery = druidNodeDiscoveryProvider.getForNodeRole(NodeRole.BROKER);
}
/**
* Creates and returns a {@link Request} after choosing a broker.
*/
public Request makeRequest(HttpMethod httpMethod, String urlPath) throws IOException
{
String host = ClientUtils.pickOneHost(druidNodeDiscovery);
if (host == null) {
throw DruidException.forPersona(DruidException.Persona.ADMIN)
.ofCategory(DruidException.Category.NOT_FOUND)
.build("A leader node could not be found for [%s] service. Check the logs to validate that service is healthy.", NodeRole.BROKER);
}
return new Request(httpMethod, new URL(StringUtils.format("%s%s", host, urlPath)));
}
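/**
* Sends the given request to a broker, retrying up to {@link #MAX_RETRIES} times on transient failures such as
* IO/channel errors and 503/504 responses. The broker host is re-resolved from service discovery on each attempt.
*/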
public String sendQuery(final Request request) throws Exception
{
return RetryUtils.retry(
() -> {
Request newRequestUrl = getNewRequestUrl(request);
final StringFullResponseHolder fullResponseHolder = brokerHttpClient.go(newRequestUrl, new StringFullResponseHandler(StandardCharsets.UTF_8)).get();
HttpResponseStatus responseStatus = fullResponseHolder.getResponse().getStatus();
if (HttpResponseStatus.SERVICE_UNAVAILABLE.equals(responseStatus)
|| HttpResponseStatus.GATEWAY_TIMEOUT.equals(responseStatus)) {
throw DruidException.forPersona(DruidException.Persona.OPERATOR)
.ofCategory(DruidException.Category.RUNTIME_FAILURE)
.build("Request to broker failed due to failed response status: [%s]", responseStatus);
} else if (responseStatus.getCode() != HttpServletResponse.SC_OK) {
throw DruidException.forPersona(DruidException.Persona.OPERATOR)
.ofCategory(DruidException.Category.RUNTIME_FAILURE)
.build("Request to broker failed due to failed response code: [%s]", responseStatus.getCode());
}
return fullResponseHolder.getContent();
},
(throwable) -> {
if (throwable instanceof ExecutionException) {
return throwable.getCause() instanceof IOException || throwable.getCause() instanceof ChannelException;
}
return throwable instanceof IOE;
},
MAX_RETRIES
);
}
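/**
* Rebuilds the request against a freshly picked broker host, so that retries are not pinned to a broker that may
* have become unavailable.
*/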
private Request getNewRequestUrl(Request oldRequest)
{
try {
return ClientUtils.withUrl(
oldRequest,
new URL(StringUtils.format("%s%s", ClientUtils.pickOneHost(druidNodeDiscovery), oldRequest.getUrl().getPath()))
);
}
catch (MalformedURLException e) {
// Not an IOException; this is our own fault.
throw DruidException.defensive(
"Failed to build url with path[%] and query string [%s].",
oldRequest.getUrl().getPath(),
oldRequest.getUrl().getQuery()
);
}
}
}

View File

@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.discovery;
import com.google.common.collect.Lists;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.http.client.Request;
import javax.annotation.Nullable;
import java.net.URL;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
/**
* Utility class for shared client methods.
*/
public class ClientUtils
{
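/**
* Returns the base URL ("scheme://host:port") of a randomly chosen discovered node, or null if no nodes of the
* requested role have been discovered yet.
*/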
@Nullable
public static String pickOneHost(DruidNodeDiscovery druidNodeDiscovery)
{
Iterator<DiscoveryDruidNode> iter = druidNodeDiscovery.getAllNodes().iterator();
List<DiscoveryDruidNode> discoveryDruidNodeList = Lists.newArrayList(iter);
if (!discoveryDruidNodeList.isEmpty()) {
DiscoveryDruidNode node = discoveryDruidNodeList.get(ThreadLocalRandom.current().nextInt(discoveryDruidNodeList.size()));
return StringUtils.format(
"%s://%s",
node.getDruidNode().getServiceScheme(),
node.getDruidNode().getHostAndPortToUse()
);
}
return null;
}
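/**
* Returns a copy of the given request that targets the new URL, preserving the original method, headers, and content.
*/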
public static Request withUrl(Request old, URL url)
{
Request req = new Request(old.getMethod(), url);
req.addHeaderValues(old.getHeaders());
if (old.hasContent()) {
req.setContent(old.getContent().copy());
}
return req;
}
}

View File

@ -39,12 +39,10 @@ import org.jboss.netty.channel.ChannelException;
import org.jboss.netty.handler.codec.http.HttpMethod;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import javax.annotation.Nullable;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
@ -202,7 +200,7 @@ public class DruidLeaderClient
redirectUrl.getPort()
));
request = withUrl(request, redirectUrl);
request = ClientUtils.withUrl(request, redirectUrl);
} else if (HttpResponseStatus.SERVICE_UNAVAILABLE.equals(responseStatus)
|| HttpResponseStatus.GATEWAY_TIMEOUT.equals(responseStatus)) {
log.warn(
@ -260,7 +258,7 @@ public class DruidLeaderClient
{
final String leader = currentKnownLeader.accumulateAndGet(
null,
(current, given) -> current == null || !cached ? pickOneHost() : current
(current, given) -> current == null || !cached ? ClientUtils.pickOneHost(druidNodeDiscovery) : current
);
if (leader == null) {
@ -274,43 +272,17 @@ public class DruidLeaderClient
}
}
@Nullable
private String pickOneHost()
{
Iterator<DiscoveryDruidNode> iter = druidNodeDiscovery.getAllNodes().iterator();
if (iter.hasNext()) {
DiscoveryDruidNode node = iter.next();
return StringUtils.format(
"%s://%s",
node.getDruidNode().getServiceScheme(),
node.getDruidNode().getHostAndPortToUse()
);
}
return null;
}
private Request withUrl(Request old, URL url)
{
Request req = new Request(old.getMethod(), url);
req.addHeaderValues(old.getHeaders());
if (old.hasContent()) {
req.setContent(old.getContent());
}
return req;
}
private Request getNewRequestUrlInvalidatingCache(Request oldRequest) throws IOException
{
try {
Request newRequest;
if (oldRequest.getUrl().getQuery() == null) {
newRequest = withUrl(
newRequest = ClientUtils.withUrl(
oldRequest,
new URL(StringUtils.format("%s%s", getCurrentKnownLeader(false), oldRequest.getUrl().getPath()))
);
} else {
newRequest = withUrl(
newRequest = ClientUtils.withUrl(
oldRequest,
new URL(StringUtils.format(
"%s%s?%s",

View File

@ -0,0 +1,154 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.discovery;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Injector;
import com.google.inject.Key;
import com.google.inject.name.Names;
import org.apache.druid.guice.GuiceInjectors;
import org.apache.druid.guice.Jerseys;
import org.apache.druid.guice.JsonConfigProvider;
import org.apache.druid.guice.LazySingleton;
import org.apache.druid.guice.LifecycleModule;
import org.apache.druid.guice.annotations.Self;
import org.apache.druid.initialization.Initialization;
import org.apache.druid.java.util.http.client.HttpClient;
import org.apache.druid.java.util.http.client.Request;
import org.apache.druid.server.DruidNode;
import org.apache.druid.server.initialization.BaseJettyTest;
import org.apache.druid.server.initialization.jetty.JettyServerInitializer;
import org.easymock.EasyMock;
import org.eclipse.jetty.server.Server;
import org.jboss.netty.handler.codec.http.HttpMethod;
import org.junit.Assert;
import org.junit.Test;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.nio.charset.StandardCharsets;
public class BrokerClientTest extends BaseJettyTest
{
private DiscoveryDruidNode discoveryDruidNode;
private HttpClient httpClient;
@Override
protected Injector setupInjector()
{
final DruidNode node = new DruidNode("test", "localhost", false, null, null, true, false);
discoveryDruidNode = new DiscoveryDruidNode(node, NodeRole.BROKER, ImmutableMap.of());
Injector injector = Initialization.makeInjectorWithModules(
GuiceInjectors.makeStartupInjector(), ImmutableList.of(
binder -> {
JsonConfigProvider.bindInstance(
binder,
Key.get(DruidNode.class, Self.class),
node
);
binder.bind(Integer.class).annotatedWith(Names.named("port")).toInstance(node.getPlaintextPort());
binder.bind(JettyServerInitializer.class).to(DruidLeaderClientTest.TestJettyServerInitializer.class).in(LazySingleton.class);
Jerseys.addResource(binder, SimpleResource.class);
LifecycleModule.register(binder, Server.class);
}
)
);
httpClient = injector.getInstance(BaseJettyTest.ClientHolder.class).getClient();
return injector;
}
@Test
public void testSimple() throws Exception
{
DruidNodeDiscovery druidNodeDiscovery = EasyMock.createMock(DruidNodeDiscovery.class);
EasyMock.expect(druidNodeDiscovery.getAllNodes()).andReturn(ImmutableList.of(discoveryDruidNode)).anyTimes();
DruidNodeDiscoveryProvider druidNodeDiscoveryProvider = EasyMock.createMock(DruidNodeDiscoveryProvider.class);
EasyMock.expect(druidNodeDiscoveryProvider.getForNodeRole(NodeRole.BROKER)).andReturn(druidNodeDiscovery);
EasyMock.replay(druidNodeDiscovery, druidNodeDiscoveryProvider);
BrokerClient brokerClient = new BrokerClient(
httpClient,
druidNodeDiscoveryProvider
);
Request request = brokerClient.makeRequest(HttpMethod.POST, "/simple/direct");
request.setContent("hello".getBytes(StandardCharsets.UTF_8));
Assert.assertEquals("hello", brokerClient.sendQuery(request));
}
@Test
public void testError() throws Exception
{
DruidNodeDiscovery druidNodeDiscovery = EasyMock.createMock(DruidNodeDiscovery.class);
EasyMock.expect(druidNodeDiscovery.getAllNodes()).andReturn(ImmutableList.of(discoveryDruidNode)).anyTimes();
DruidNodeDiscoveryProvider druidNodeDiscoveryProvider = EasyMock.createMock(DruidNodeDiscoveryProvider.class);
EasyMock.expect(druidNodeDiscoveryProvider.getForNodeRole(NodeRole.BROKER)).andReturn(druidNodeDiscovery);
EasyMock.replay(druidNodeDiscovery, druidNodeDiscoveryProvider);
BrokerClient brokerClient = new BrokerClient(
httpClient,
druidNodeDiscoveryProvider
);
Request request = brokerClient.makeRequest(HttpMethod.POST, "/simple/flakey");
request.setContent("hello".getBytes(StandardCharsets.UTF_8));
Assert.assertEquals("hello", brokerClient.sendQuery(request));
}
@Path("/simple")
public static class SimpleResource
{
private static int attempt = 0;
@POST
@Path("/direct")
@Produces(MediaType.APPLICATION_JSON)
public Response direct(String input)
{
if ("hello".equals(input)) {
return Response.ok("hello").build();
} else {
return Response.serverError().build();
}
}
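// Fails with 504 GATEWAY_TIMEOUT for the first three attempts and then succeeds, exercising the retry
// behavior of BrokerClient#sendQuery.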
@POST
@Path("/flakey")
@Produces(MediaType.APPLICATION_JSON)
public Response redirecting()
{
if (attempt > 2) {
return Response.ok("hello").build();
} else {
attempt += 1;
return Response.status(504).build();
}
}
}
}

View File

@ -281,7 +281,7 @@ public class DruidLeaderClientTest extends BaseJettyTest
Assert.assertEquals("http://localhost:1234/", druidLeaderClient.findCurrentLeader());
}
private static class TestJettyServerInitializer implements JettyServerInitializer
static class TestJettyServerInitializer implements JettyServerInitializer
{
@Override
public void initialize(Server server, Injector injector)

View File

@ -440,6 +440,7 @@ preemptible
prefetch
prefetched
prefetching
precached
prepend
prepended
prepending
@ -747,6 +748,7 @@ TooManyWorkers
NotEnoughMemory
WorkerFailed
WorkerRpcFailed
TIMED_OUT
# MSQ context parameters
maxNumTasks
taskAssignment