diff --git a/docs/api-reference/sql-ingestion-api.md b/docs/api-reference/sql-ingestion-api.md index 3109537c4a2..3daadfa5085 100644 --- a/docs/api-reference/sql-ingestion-api.md +++ b/docs/api-reference/sql-ingestion-api.md @@ -288,7 +288,19 @@ The response shows an example report for a query. "startTime": "2022-09-14T22:12:09.266Z", "durationMs": 28227, "pendingTasks": 0, - "runningTasks": 2 + "runningTasks": 2, + "segmentLoadWaiterStatus": { + "state": "SUCCESS", + "dataSource": "kttm_simple", + "startTime": "2022-09-14T23:12:09.266Z", + "duration": 15, + "totalSegments": 1, + "usedSegments": 1, + "precachedSegments": 0, + "onDemandSegments": 0, + "pendingSegments": 0, + "unknownSegments": 0 + } }, "stages": [ { @@ -593,6 +605,16 @@ The following table describes the response fields when you retrieve a report for | `multiStageQuery.payload.status.durationMs` | Milliseconds elapsed after the query has started running. -1 denotes that the query hasn't started running yet. | | `multiStageQuery.payload.status.pendingTasks` | Number of tasks that are not fully started. -1 denotes that the number is currently unknown. | | `multiStageQuery.payload.status.runningTasks` | Number of currently running tasks. Should be at least 1 since the controller is included. | +| `multiStageQuery.payload.status.segmentLoadWaiterStatus` | Segment loading container. Only present after the segments have been published. | +| `multiStageQuery.payload.status.segmentLoadWaiterStatus.state` | Either INIT, WAITING, SUCCESS, FAILED or TIMED_OUT. | +| `multiStageQuery.payload.status.segmentLoadWaiterStatus.startTime` | Time since which the controller has been waiting for the segments to finish loading. | +| `multiStageQuery.payload.status.segmentLoadWaiterStatus.duration` | The duration in milliseconds that the controller has been waiting for the segments to load. | +| `multiStageQuery.payload.status.segmentLoadWaiterStatus.totalSegments` | The total number of segments generated by the job. This includes tombstone segments (if any). 
| +| `multiStageQuery.payload.status.segmentLoadWaiterStatus.usedSegments` | The number of segments which are marked as used based on the load rules. Unused segments can be cleaned up at any time. | +| `multiStageQuery.payload.status.segmentLoadWaiterStatus.precachedSegments` | The number of segments which are marked as precached and served by historicals, as per the load rules. | +| `multiStageQuery.payload.status.segmentLoadWaiterStatus.onDemandSegments` | The number of segments which are not loaded on any historical, as per the load rules. | +| `multiStageQuery.payload.status.segmentLoadWaiterStatus.pendingSegments` | The number of segments remaining to be loaded. | +| `multiStageQuery.payload.status.segmentLoadWaiterStatus.unknownSegments` | The number of segments whose status is unknown. | | `multiStageQuery.payload.status.errorReport` | Error object. Only present if there was an error. | | `multiStageQuery.payload.status.errorReport.taskId` | The task that reported the error, if known. May be a controller task or a worker task. | | `multiStageQuery.payload.status.errorReport.host` | The hostname and port of the task that reported the error, if known. 
| diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index d883a587e9b..20033480f10 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -44,6 +44,7 @@ import org.apache.druid.data.input.StringTuple; import org.apache.druid.data.input.impl.DimensionSchema; import org.apache.druid.data.input.impl.DimensionsSpec; import org.apache.druid.data.input.impl.TimestampSpec; +import org.apache.druid.discovery.BrokerClient; import org.apache.druid.frame.allocation.ArenaMemoryAllocator; import org.apache.druid.frame.channel.FrameChannelSequence; import org.apache.druid.frame.key.ClusterBy; @@ -63,6 +64,7 @@ import org.apache.druid.indexing.common.TaskLock; import org.apache.druid.indexing.common.TaskLockType; import org.apache.druid.indexing.common.TaskReport; import org.apache.druid.indexing.common.actions.LockListAction; +import org.apache.druid.indexing.common.actions.LockReleaseAction; import org.apache.druid.indexing.common.actions.MarkSegmentsAsUnusedAction; import org.apache.druid.indexing.common.actions.SegmentAllocateAction; import org.apache.druid.indexing.common.actions.SegmentTransactionalInsertAction; @@ -292,6 +294,7 @@ public class ControllerImpl implements Controller private WorkerMemoryParameters workerMemoryParameters; private boolean isDurableStorageEnabled; private boolean isFaultToleranceEnabled; + private volatile SegmentLoadWaiter segmentLoadWaiter; public ControllerImpl( final MSQControllerTask task, @@ -437,6 +440,45 @@ public class ControllerImpl implements Controller } } + if (queryKernel != null && queryKernel.isSuccess()) { + // If successful, encourage the tasks to exit successfully. 
+ postFinishToAllTasks(); + workerTaskLauncher.stop(false); + } else { + // If not successful, cancel running tasks. + if (workerTaskLauncher != null) { + workerTaskLauncher.stop(true); + } + } + + // Wait for worker tasks to exit. Ignore their return status. At this point, we've done everything we need to do, + // so we don't care about the task exit status. + if (workerTaskRunnerFuture != null) { + try { + workerTaskRunnerFuture.get(); + } + catch (Exception ignored) { + // Suppress. + } + } + + + try { + releaseTaskLocks(); + + cleanUpDurableStorageIfNeeded(); + + if (queryKernel != null && queryKernel.isSuccess()) { + if (segmentLoadWaiter != null) { + // If successful and there are segments created, segmentLoadWaiter should wait for them to become available. + segmentLoadWaiter.waitForSegmentsToLoad(); + } + } + } + catch (Exception e) { + log.warn(e, "Exception thrown during cleanup. Ignoring it and writing task report."); + } + try { // Write report even if something went wrong. final MSQStagesReport stagesReport; @@ -488,7 +530,8 @@ public class ControllerImpl implements Controller workerWarnings, queryStartTime, new Interval(queryStartTime, DateTimes.nowUtc()).toDurationMillis(), - workerTaskLauncher + workerTaskLauncher, + segmentLoadWaiter ), stagesReport, countersSnapshot, @@ -504,30 +547,6 @@ public class ControllerImpl implements Controller log.warn(e, "Error encountered while writing task report. Skipping."); } - if (queryKernel != null && queryKernel.isSuccess()) { - // If successful, encourage the tasks to exit successfully. - postFinishToAllTasks(); - workerTaskLauncher.stop(false); - } else { - // If not successful, cancel running tasks. - if (workerTaskLauncher != null) { - workerTaskLauncher.stop(true); - } - } - - // Wait for worker tasks to exit. Ignore their return status. At this point, we've done everything we need to do, - // so we don't care about the task exit status. 
- if (workerTaskRunnerFuture != null) { - try { - workerTaskRunnerFuture.get(); - } - catch (Exception ignored) { - // Suppress. - } - } - - cleanUpDurableStorageIfNeeded(); - if (taskStateForReport == TaskState.SUCCESS) { return TaskStatus.success(id()); } else { @@ -536,6 +555,23 @@ public class ControllerImpl implements Controller } } + /** + * Releases the locks obtained by the task. + */ + private void releaseTaskLocks() throws IOException + { + final List locks; + try { + locks = context.taskActionClient().submit(new LockListAction()); + for (final TaskLock lock : locks) { + context.taskActionClient().submit(new LockReleaseAction(lock.getInterval())); + } + } + catch (IOException e) { + throw new IOException("Failed to release locks", e); + } + } + /** * Adds some logic to {@link #kernelManipulationQueue}, where it will, in due time, be executed by the main * controller loop in {@link RunQueryUntilDone#run()}. @@ -875,7 +911,8 @@ public class ControllerImpl implements Controller workerWarnings, queryStartTime, queryStartTime == null ? -1L : new Interval(queryStartTime, DateTimes.nowUtc()).toDurationMillis(), - workerTaskLauncher + workerTaskLauncher, + segmentLoadWaiter ), makeStageReport( queryDef, @@ -1316,17 +1353,36 @@ public class ControllerImpl implements Controller if (segmentsWithTombstones.isEmpty()) { // Nothing to publish, only drop. We already validated that the intervalsToDrop do not have any // partially-overlapping segments, so it's safe to drop them as intervals instead of as specific segments. + // This should not need a segment load wait as segments are marked as unused immediately. 
for (final Interval interval : intervalsToDrop) { context.taskActionClient() .submit(new MarkSegmentsAsUnusedAction(task.getDataSource(), interval)); } } else { + Set versionsToAwait = segmentsWithTombstones.stream().map(DataSegment::getVersion).collect(Collectors.toSet()); + segmentLoadWaiter = new SegmentLoadWaiter( + context.injector().getInstance(BrokerClient.class), + context.jsonMapper(), + task.getDataSource(), + versionsToAwait, + segmentsWithTombstones.size(), + true + ); performSegmentPublish( context.taskActionClient(), SegmentTransactionalInsertAction.overwriteAction(null, segmentsWithTombstones) ); } } else if (!segments.isEmpty()) { + Set versionsToAwait = segments.stream().map(DataSegment::getVersion).collect(Collectors.toSet()); + segmentLoadWaiter = new SegmentLoadWaiter( + context.injector().getInstance(BrokerClient.class), + context.jsonMapper(), + task.getDataSource(), + versionsToAwait, + segments.size(), + true + ); // Append mode. performSegmentPublish( context.taskActionClient(), @@ -2072,7 +2128,8 @@ public class ControllerImpl implements Controller final Queue errorReports, @Nullable final DateTime queryStartTime, final long queryDuration, - MSQWorkerTaskLauncher taskLauncher + MSQWorkerTaskLauncher taskLauncher, + final SegmentLoadWaiter segmentLoadWaiter ) { int pendingTasks = -1; @@ -2083,6 +2140,9 @@ public class ControllerImpl implements Controller pendingTasks = workerTaskCount.getPendingWorkerCount(); runningTasks = workerTaskCount.getRunningWorkerCount() + 1; // To account for controller. } + + SegmentLoadWaiter.SegmentLoadWaiterStatus status = segmentLoadWaiter == null ? 
null : segmentLoadWaiter.status(); + return new MSQStatusReport( taskState, errorReport, @@ -2090,7 +2150,8 @@ public class ControllerImpl implements Controller queryStartTime, queryDuration, pendingTasks, - runningTasks + runningTasks, + status ); } @@ -2259,6 +2320,7 @@ public class ControllerImpl implements Controller throwKernelExceptionIfNotUnknown(); } + updateLiveReportMaps(); cleanUpEffectivelyFinishedStages(); return Pair.of(queryKernel, workerTaskLauncherFuture); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadWaiter.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadWaiter.java new file mode 100644 index 00000000000..3a54c41e410 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/SegmentLoadWaiter.java @@ -0,0 +1,422 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.exec; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.druid.discovery.BrokerClient; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.java.util.http.client.Request; +import org.apache.druid.sql.http.ResultFormat; +import org.apache.druid.sql.http.SqlQuery; +import org.jboss.netty.handler.codec.http.HttpMethod; +import org.joda.time.DateTime; +import org.joda.time.Interval; + +import javax.annotation.Nullable; +import javax.ws.rs.core.MediaType; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.TreeSet; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Class that periodically checks with the broker if all the segments generated are loaded by querying the sys table + * and blocks till it is complete. This will account for and not wait for segments that would never be loaded due to + * load rules. Should only be called if the query generates new segments or tombstones. + *
+ * If an exception is thrown during operation, this will simply log the exception and exit without failing the task, + * since the segments have already been published successfully, and should be loaded eventually. + *
+ * If the segments are not loaded within {@link #TIMEOUT_DURATION_MILLIS} milliseconds, this logs a warning and exits + * for the same reason. + */ +public class SegmentLoadWaiter +{ + private static final Logger log = new Logger(SegmentLoadWaiter.class); + private static final long SLEEP_DURATION_MILLIS = TimeUnit.SECONDS.toMillis(5); + private static final long TIMEOUT_DURATION_MILLIS = TimeUnit.MINUTES.toMillis(10); + + /** + * The query sent to the broker. This query uses replication_factor to determine how many copies of a segment have to be + * loaded as per the load rules. + * - If a segment is not used, the broker will not have any information about it, hence, a COUNT(*) should return the used count only. + * - If replication_factor is more than 0, the segment will be loaded on historicals and needs to be waited for. + * - If replication_factor is 0, that means that the segment will never be loaded on a historical and does not need to + * be waited for. + * - If replication_factor is -1, the replication factor is not known currently and will become known after a load rule + * evaluation. + *
+ * See https://github.com/apache/druid/pull/14403 for more details about replication_factor + */ + private static final String LOAD_QUERY = "SELECT COUNT(*) AS usedSegments,\n" + + "COUNT(*) FILTER (WHERE is_published = 1 AND replication_factor > 0) AS precachedSegments,\n" + + "COUNT(*) FILTER (WHERE is_published = 1 AND replication_factor = 0) AS onDemandSegments,\n" + + "COUNT(*) FILTER (WHERE is_available = 0 AND is_published = 1 AND replication_factor != 0) AS pendingSegments,\n" + + "COUNT(*) FILTER (WHERE replication_factor = -1) AS unknownSegments\n" + + "FROM sys.segments\n" + + "WHERE datasource = '%s' AND is_overshadowed = 0 AND version = '%s'"; + + private final BrokerClient brokerClient; + private final ObjectMapper objectMapper; + // Map of version vs latest load status. + private final Map versionToLoadStatusMap; + private final String datasource; + private final Set versionsToAwait; + private final int totalSegmentsGenerated; + private final boolean doWait; + private final AtomicReference status; + + public SegmentLoadWaiter( + BrokerClient brokerClient, + ObjectMapper objectMapper, + String datasource, + Set versionsToAwait, + int totalSegmentsGenerated, + boolean doWait + ) + { + this.brokerClient = brokerClient; + this.objectMapper = objectMapper; + this.datasource = datasource; + this.versionsToAwait = new TreeSet<>(versionsToAwait); + this.versionToLoadStatusMap = new HashMap<>(); + this.totalSegmentsGenerated = totalSegmentsGenerated; + this.status = new AtomicReference<>(new SegmentLoadWaiterStatus(State.INIT, null, 0, totalSegmentsGenerated, 0, 0, 0, 0, totalSegmentsGenerated)); + this.doWait = doWait; + } + + /** + * Uses broker client to check if all segments created by the ingestion have been loaded and updates the {@link #status} + * periodically. + *
+ * If an exception is thrown during operation, this will log the exception and return without failing the task, + * since the segments have already been published successfully, and should be loaded eventually. + *
+ * Only expected to be called from the main controller thread. + */ + public void waitForSegmentsToLoad() + { + DateTime startTime = DateTimes.nowUtc(); + boolean hasAnySegmentBeenLoaded = false; + + try { + while (!versionsToAwait.isEmpty()) { + // Check the timeout and exit if exceeded. + long runningMillis = new Interval(startTime, DateTimes.nowUtc()).toDurationMillis(); + if (runningMillis > TIMEOUT_DURATION_MILLIS) { + log.warn("Runtime [%s] exceeded timeout [%s] while waiting for segments to load. Exiting.", runningMillis, TIMEOUT_DURATION_MILLIS); + updateStatus(State.TIMED_OUT, startTime); + return; + } + + Iterator iterator = versionsToAwait.iterator(); + + // Query the broker for all pending versions + while (iterator.hasNext()) { + String version = iterator.next(); + + // Fetch the load status for this version from the broker + VersionLoadStatus loadStatus = fetchLoadStatusForVersion(version); + versionToLoadStatusMap.put(version, loadStatus); + + hasAnySegmentBeenLoaded = hasAnySegmentBeenLoaded || loadStatus.getUsedSegments() > 0; + + // If loading is done for this stage, remove it from future loops. + if (hasAnySegmentBeenLoaded && loadStatus.isLoadingComplete()) { + iterator.remove(); + } + } + + if (!versionsToAwait.isEmpty()) { + // Update the status. + updateStatus(State.WAITING, startTime); + + // Sleep for a while before retrying. + waitIfNeeded(SLEEP_DURATION_MILLIS); + } + } + } + catch (Exception e) { + log.warn(e, "Exception occurred while waiting for segments to load. Exiting."); + + // Update the status and return. + updateStatus(State.FAILED, startTime); + return; + } + // Update the status. 
+ updateStatus(State.SUCCESS, startTime); + } + + private void waitIfNeeded(long waitTimeMillis) throws Exception + { + if (doWait) { + Thread.sleep(waitTimeMillis); + } + } + + /** + * Updates the {@link #status} with the latest details based on {@link #versionToLoadStatusMap} + */ + private void updateStatus(State state, DateTime startTime) + { + int pendingSegmentCount = 0, usedSegmentsCount = 0, precachedSegmentCount = 0, onDemandSegmentCount = 0, unknownSegmentCount = 0; + for (Map.Entry entry : versionToLoadStatusMap.entrySet()) { + usedSegmentsCount += entry.getValue().getUsedSegments(); + precachedSegmentCount += entry.getValue().getPrecachedSegments(); + onDemandSegmentCount += entry.getValue().getOnDemandSegments(); + unknownSegmentCount += entry.getValue().getUnknownSegments(); + pendingSegmentCount += entry.getValue().getPendingSegments(); + } + + long runningMillis = new Interval(startTime, DateTimes.nowUtc()).toDurationMillis(); + status.set( + new SegmentLoadWaiterStatus( + state, + startTime, + runningMillis, + totalSegmentsGenerated, + usedSegmentsCount, + precachedSegmentCount, + onDemandSegmentCount, + pendingSegmentCount, + unknownSegmentCount + ) + ); + } + + /** + * Uses {@link #brokerClient} to fetch latest load status for a given version. Converts the response into a + * {@link VersionLoadStatus} and returns it. + */ + private VersionLoadStatus fetchLoadStatusForVersion(String version) throws Exception + { + Request request = brokerClient.makeRequest(HttpMethod.POST, "/druid/v2/sql/"); + SqlQuery sqlQuery = new SqlQuery(StringUtils.format(LOAD_QUERY, datasource, version), + ResultFormat.OBJECTLINES, + false, false, false, null, null); + request.setContent(MediaType.APPLICATION_JSON, objectMapper.writeValueAsBytes(sqlQuery)); + + String response = brokerClient.sendQuery(request); + + if (response.trim().isEmpty()) { + // If no segments are returned for a version, all segments have been dropped by a drop rule. 
+ return new VersionLoadStatus(0, 0, 0, 0, 0); + } else { + return objectMapper.readValue(response, VersionLoadStatus.class); + } + } + + /** + * Returns the current status of the load. + */ + public SegmentLoadWaiterStatus status() + { + return status.get(); + } + + public static class SegmentLoadWaiterStatus + { + private final State state; + private final DateTime startTime; + private final long duration; + private final int totalSegments; + private final int usedSegments; + private final int precachedSegments; + private final int onDemandSegments; + private final int pendingSegments; + private final int unknownSegments; + + @JsonCreator + public SegmentLoadWaiterStatus( + @JsonProperty("state") SegmentLoadWaiter.State state, + @JsonProperty("startTime") @Nullable DateTime startTime, + @JsonProperty("duration") long duration, + @JsonProperty("totalSegments") int totalSegments, + @JsonProperty("usedSegments") int usedSegments, + @JsonProperty("precachedSegments") int precachedSegments, + @JsonProperty("onDemandSegments") int onDemandSegments, + @JsonProperty("pendingSegments") int pendingSegments, + @JsonProperty("unknownSegments") int unknownSegments + ) + { + this.state = state; + this.startTime = startTime; + this.duration = duration; + this.totalSegments = totalSegments; + this.usedSegments = usedSegments; + this.precachedSegments = precachedSegments; + this.onDemandSegments = onDemandSegments; + this.pendingSegments = pendingSegments; + this.unknownSegments = unknownSegments; + } + + @JsonProperty + public SegmentLoadWaiter.State getState() + { + return state; + } + + @Nullable + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_NULL) + public DateTime getStartTime() + { + return startTime; + } + + @JsonProperty + public long getDuration() + { + return duration; + } + + @JsonProperty + public long getTotalSegments() + { + return totalSegments; + } + + @JsonProperty + public int getUsedSegments() + { + return usedSegments; + } + + @JsonProperty + public 
int getPrecachedSegments() + { + return precachedSegments; + } + + @JsonProperty + public int getOnDemandSegments() + { + return onDemandSegments; + } + + @JsonProperty + public int getPendingSegments() + { + return pendingSegments; + } + + @JsonProperty + public int getUnknownSegments() + { + return unknownSegments; + } + } + + public enum State + { + /** + * Initial state after being initialised with the segment versions and before #waitForSegmentsToLoad has been called. + */ + INIT, + /** + * Not all segments that need to be loaded have been loaded yet. The load status is periodically being queried from + * the broker. + */ + WAITING, + /** + * All segments which need to be loaded have been loaded, and the SegmentLoadWaiter exited successfully. + */ + SUCCESS, + /** + * An exception occurred while checking load status. The SegmentLoadWaiter exited without failing the task. + */ + FAILED, + /** + * The time spent waiting for segments to load exceeded org.apache.druid.msq.exec.SegmentLoadWaiter#TIMEOUT_DURATION_MILLIS. + * The SegmentLoadWaiter exited without failing the task. 
+ */ + TIMED_OUT + } + + public static class VersionLoadStatus + { + private final int usedSegments; + private final int precachedSegments; + private final int onDemandSegments; + private final int pendingSegments; + private final int unknownSegments; + + @JsonCreator + public VersionLoadStatus( + @JsonProperty("usedSegments") int usedSegments, + @JsonProperty("precachedSegments") int precachedSegments, + @JsonProperty("onDemandSegments") int onDemandSegments, + @JsonProperty("pendingSegments") int pendingSegments, + @JsonProperty("unknownSegments") int unknownSegments + ) + { + this.usedSegments = usedSegments; + this.precachedSegments = precachedSegments; + this.onDemandSegments = onDemandSegments; + this.pendingSegments = pendingSegments; + this.unknownSegments = unknownSegments; + } + + @JsonProperty + public int getUsedSegments() + { + return usedSegments; + } + + @JsonProperty + public int getPrecachedSegments() + { + return precachedSegments; + } + + @JsonProperty + public int getOnDemandSegments() + { + return onDemandSegments; + } + + @JsonProperty + public int getPendingSegments() + { + return pendingSegments; + } + + @JsonProperty + public int getUnknownSegments() + { + return unknownSegments; + } + + @JsonIgnore + public boolean isLoadingComplete() + { + return pendingSegments == 0 && (usedSegments == precachedSegments + onDemandSegments); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java index f391b08b671..dcc81d86864 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQWorkerTaskLauncher.java @@ -182,7 +182,19 @@ public class MSQWorkerTaskLauncher */ public void stop(final boolean interrupt) { - if 
(state.compareAndSet(State.STARTED, State.STOPPED)) { + if (state.compareAndSet(State.NEW, State.STOPPED)) { + state.set(State.STOPPED); + if (interrupt) { + cancelTasksOnStop.set(true); + } + + synchronized (taskIds) { + // Wake up sleeping mainLoop. + taskIds.notifyAll(); + } + exec.shutdown(); + stopFuture.set(null); + } else if (state.compareAndSet(State.STARTED, State.STOPPED)) { if (interrupt) { cancelTasksOnStop.set(true); } @@ -466,9 +478,13 @@ public class MSQWorkerTaskLauncher public WorkerCount getWorkerTaskCount() { synchronized (taskIds) { - int runningTasks = fullyStartedTasks.size(); - int pendingTasks = desiredTaskCount - runningTasks; - return new WorkerCount(runningTasks, pendingTasks); + if (stopFuture.isDone()) { + return new WorkerCount(0, 0); + } else { + int runningTasks = fullyStartedTasks.size(); + int pendingTasks = desiredTaskCount - runningTasks; + return new WorkerCount(runningTasks, pendingTasks); + } } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java index 3791bc82e16..ea721d84f47 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQStatusReport.java @@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Preconditions; import org.apache.druid.indexer.TaskState; +import org.apache.druid.msq.exec.SegmentLoadWaiter; import org.apache.druid.msq.indexing.error.MSQErrorReport; import org.joda.time.DateTime; @@ -50,6 +51,9 @@ public class MSQStatusReport private final int runningTasks; + @Nullable + private final SegmentLoadWaiter.SegmentLoadWaiterStatus segmentLoadWaiterStatus; + @JsonCreator public 
MSQStatusReport( @JsonProperty("status") TaskState status, @@ -58,7 +62,8 @@ public class MSQStatusReport @JsonProperty("startTime") @Nullable DateTime startTime, @JsonProperty("durationMs") long durationMs, @JsonProperty("pendingTasks") int pendingTasks, - @JsonProperty("runningTasks") int runningTasks + @JsonProperty("runningTasks") int runningTasks, + @JsonProperty("segmentLoadWaiterStatus") @Nullable SegmentLoadWaiter.SegmentLoadWaiterStatus segmentLoadWaiterStatus ) { this.status = Preconditions.checkNotNull(status, "status"); @@ -68,6 +73,7 @@ public class MSQStatusReport this.durationMs = durationMs; this.pendingTasks = pendingTasks; this.runningTasks = runningTasks; + this.segmentLoadWaiterStatus = segmentLoadWaiterStatus; } @JsonProperty @@ -117,6 +123,14 @@ public class MSQStatusReport return durationMs; } + @Nullable + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_NULL) + public SegmentLoadWaiter.SegmentLoadWaiterStatus getSegmentLoadWaiterStatus() + { + return segmentLoadWaiterStatus; + } + @Override public boolean equals(Object o) { diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadWaiterTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadWaiterTest.java new file mode 100644 index 00000000000..e14fa5faec2 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/SegmentLoadWaiterTest.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableSet; +import org.apache.druid.discovery.BrokerClient; +import org.apache.druid.java.util.http.client.Request; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class SegmentLoadWaiterTest +{ + private static final String TEST_DATASOURCE = "testDatasource"; + + private SegmentLoadWaiter segmentLoadWaiter; + + private BrokerClient brokerClient; + + /** + * Single version created, loaded after 5 attempts + */ + @Test + public void testSingleVersionWaitsForLoadCorrectly() throws Exception + { + brokerClient = mock(BrokerClient.class); + + doReturn(mock(Request.class)).when(brokerClient).makeRequest(any(), anyString()); + doAnswer(new Answer() + { + int timesInvoked = 0; + @Override + public String answer(InvocationOnMock invocation) throws Throwable + { + timesInvoked += 1; + SegmentLoadWaiter.VersionLoadStatus loadStatus = new SegmentLoadWaiter.VersionLoadStatus(5, timesInvoked, 0, 5 - timesInvoked, 0); + return new ObjectMapper().writeValueAsString(loadStatus); + } + }).when(brokerClient).sendQuery(any()); + segmentLoadWaiter = 
new SegmentLoadWaiter(brokerClient, new ObjectMapper(), TEST_DATASOURCE, ImmutableSet.of("version1"), 5, false); + segmentLoadWaiter.waitForSegmentsToLoad(); + + verify(brokerClient, times(5)).sendQuery(any()); + } + + @Test + public void testMultipleVersionWaitsForLoadCorrectly() throws Exception + { + brokerClient = mock(BrokerClient.class); + + doReturn(mock(Request.class)).when(brokerClient).makeRequest(any(), anyString()); + doAnswer(new Answer() + { + int timesInvoked = 0; + @Override + public String answer(InvocationOnMock invocation) throws Throwable + { + timesInvoked += 1; + SegmentLoadWaiter.VersionLoadStatus loadStatus = new SegmentLoadWaiter.VersionLoadStatus(5, timesInvoked, 0, 5 - timesInvoked, 0); + return new ObjectMapper().writeValueAsString(loadStatus); + } + }).when(brokerClient).sendQuery(any()); + segmentLoadWaiter = new SegmentLoadWaiter(brokerClient, new ObjectMapper(), TEST_DATASOURCE, ImmutableSet.of("version1"), 5, false); + segmentLoadWaiter.waitForSegmentsToLoad(); + + verify(brokerClient, times(5)).sendQuery(any()); + } + +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java index eeac9486dc9..ef50008d48e 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java @@ -30,10 +30,12 @@ import org.apache.druid.frame.key.KeyOrder; import org.apache.druid.indexer.TaskState; import org.apache.druid.indexing.common.SingleFileTaskReportFileWriter; import org.apache.druid.indexing.common.TaskReport; +import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.guava.Yielder; import 
org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.exec.SegmentLoadWaiter; import org.apache.druid.msq.guice.MSQIndexingModule; import org.apache.druid.msq.indexing.error.MSQErrorReport; import org.apache.druid.msq.indexing.error.TooManyColumnsFault; @@ -90,10 +92,22 @@ public class MSQTaskReportTest new Object[]{"bar"} ); + SegmentLoadWaiter.SegmentLoadWaiterStatus status = new SegmentLoadWaiter.SegmentLoadWaiterStatus( + SegmentLoadWaiter.State.WAITING, + DateTimes.nowUtc(), + 200L, + 100, + 80, + 30, + 50, + 10, + 0 + ); + final MSQTaskReport report = new MSQTaskReport( TASK_ID, new MSQTaskReportPayload( - new MSQStatusReport(TaskState.SUCCESS, null, new ArrayDeque<>(), null, 0, 1, 2), + new MSQStatusReport(TaskState.SUCCESS, null, new ArrayDeque<>(), null, 0, 1, 2, status), MSQStagesReport.create( QUERY_DEFINITION, ImmutableMap.of(), @@ -142,11 +156,23 @@ public class MSQTaskReportTest @Test public void testSerdeErrorReport() throws Exception { + SegmentLoadWaiter.SegmentLoadWaiterStatus status = new SegmentLoadWaiter.SegmentLoadWaiterStatus( + SegmentLoadWaiter.State.FAILED, + DateTimes.nowUtc(), + 200L, + 100, + 80, + 30, + 50, + 10, + 0 + ); + final MSQErrorReport errorReport = MSQErrorReport.fromFault(TASK_ID, HOST, 0, new TooManyColumnsFault(10, 5)); final MSQTaskReport report = new MSQTaskReport( TASK_ID, new MSQTaskReportPayload( - new MSQStatusReport(TaskState.FAILED, errorReport, new ArrayDeque<>(), null, 0, 1, 2), + new MSQStatusReport(TaskState.FAILED, errorReport, new ArrayDeque<>(), null, 0, 1, 2, status), MSQStagesReport.create( QUERY_DEFINITION, ImmutableMap.of(), @@ -179,10 +205,22 @@ public class MSQTaskReportTest @Test public void testWriteTaskReport() throws Exception { + SegmentLoadWaiter.SegmentLoadWaiterStatus status = new SegmentLoadWaiter.SegmentLoadWaiterStatus( + SegmentLoadWaiter.State.SUCCESS, + DateTimes.nowUtc(), + 200L, + 100, + 80, + 30, + 50, 
+ 10, + 0 + ); + final MSQTaskReport report = new MSQTaskReport( TASK_ID, new MSQTaskReportPayload( - new MSQStatusReport(TaskState.SUCCESS, null, new ArrayDeque<>(), null, 0, 1, 2), + new MSQStatusReport(TaskState.SUCCESS, null, new ArrayDeque<>(), null, 0, 1, 2, status), MSQStagesReport.create( QUERY_DEFINITION, ImmutableMap.of(), diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java index a603dcb9173..ec986a93159 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java @@ -240,7 +240,8 @@ public class SqlStatementResourceTest extends MSQTestBase null, 0, 1, - 2 + 2, + null ), MSQStagesReport.create( MSQTaskReportTest.QUERY_DEFINITION, @@ -305,7 +306,8 @@ public class SqlStatementResourceTest extends MSQTestBase null, 0, 1, - 2 + 2, + null ), MSQStagesReport.create( MSQTaskReportTest.QUERY_DEFINITION, diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 3ce2d18e40d..9ebcb2ec53e 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -42,6 +42,7 @@ import org.apache.druid.common.config.NullHandling; import org.apache.druid.data.input.impl.DimensionsSpec; import org.apache.druid.data.input.impl.LongDimensionSchema; import org.apache.druid.data.input.impl.StringDimensionSchema; +import org.apache.druid.discovery.BrokerClient; import org.apache.druid.discovery.NodeRole; import 
org.apache.druid.frame.channel.FrameChannelSequence; import org.apache.druid.frame.testutil.FrameTestUtil; @@ -73,6 +74,7 @@ import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.java.util.http.client.Request; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.metadata.input.InputSourceModule; import org.apache.druid.msq.counters.CounterNames; @@ -210,6 +212,10 @@ import static org.apache.druid.sql.calcite.util.CalciteTests.DATASOURCE1; import static org.apache.druid.sql.calcite.util.CalciteTests.DATASOURCE2; import static org.apache.druid.sql.calcite.util.TestDataBuilder.ROWS1; import static org.apache.druid.sql.calcite.util.TestDataBuilder.ROWS2; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; /** * Base test runner for running MSQ unit tests. It sets up multi stage query execution environment @@ -352,7 +358,7 @@ public class MSQTestBase extends BaseCalciteQueryTest // which depends on the object mapper that the injector will provide, once it // is built, but has not yet been build while we build the SQL engine. 
@Before - public void setUp2() + public void setUp2() throws Exception { groupByBuffers = TestGroupByBuffers.createDefault(); @@ -380,6 +386,7 @@ public class MSQTestBase extends BaseCalciteQueryTest segmentManager = new MSQTestSegmentManager(segmentCacheManager, indexIO); + BrokerClient brokerClient = mock(BrokerClient.class); List modules = ImmutableList.of( binder -> { DruidProcessingConfig druidProcessingConfig = new DruidProcessingConfig() @@ -467,7 +474,8 @@ public class MSQTestBase extends BaseCalciteQueryTest ), new MSQExternalDataSourceModule(), new LookylooModule(), - new SegmentWranglerModule() + new SegmentWranglerModule(), + binder -> binder.bind(BrokerClient.class).toInstance(brokerClient) ); // adding node role injection to the modules, since CliPeon would also do that through run method Injector injector = new CoreInjectorBuilder(new StartupInjectorBuilder().build(), ImmutableSet.of(NodeRole.PEON)) @@ -477,6 +485,8 @@ public class MSQTestBase extends BaseCalciteQueryTest objectMapper = setupObjectMapper(injector); objectMapper.registerModules(sqlModule.getJacksonModules()); + doReturn(mock(Request.class)).when(brokerClient).makeRequest(any(), anyString()); + testTaskActionClient = Mockito.spy(new MSQTestTaskActionClient(objectMapper)); indexingServiceClient = new MSQTestOverlordServiceClient( objectMapper, diff --git a/server/src/main/java/org/apache/druid/discovery/BrokerClient.java b/server/src/main/java/org/apache/druid/discovery/BrokerClient.java new file mode 100644 index 00000000000..bc97c2490ef --- /dev/null +++ b/server/src/main/java/org/apache/druid/discovery/BrokerClient.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.discovery; + +import com.google.inject.Inject; +import org.apache.druid.error.DruidException; +import org.apache.druid.guice.annotations.EscalatedGlobal; +import org.apache.druid.java.util.common.IOE; +import org.apache.druid.java.util.common.RetryUtils; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.http.client.HttpClient; +import org.apache.druid.java.util.http.client.Request; +import org.apache.druid.java.util.http.client.response.StringFullResponseHandler; +import org.apache.druid.java.util.http.client.response.StringFullResponseHolder; +import org.jboss.netty.channel.ChannelException; +import org.jboss.netty.handler.codec.http.HttpMethod; +import org.jboss.netty.handler.codec.http.HttpResponseStatus; + +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.ExecutionException; + +/** + * This class facilitates interaction with Broker. 
+ */ +public class BrokerClient +{ + private static final int MAX_RETRIES = 5; + + private final HttpClient brokerHttpClient; + private final DruidNodeDiscovery druidNodeDiscovery; + + @Inject + public BrokerClient( + @EscalatedGlobal HttpClient brokerHttpClient, + DruidNodeDiscoveryProvider druidNodeDiscoveryProvider + ) + { + this.brokerHttpClient = brokerHttpClient; + this.druidNodeDiscovery = druidNodeDiscoveryProvider.getForNodeRole(NodeRole.BROKER); + } + + /** + * Creates and returns a {@link Request} after choosing a broker. + */ + public Request makeRequest(HttpMethod httpMethod, String urlPath) throws IOException + { + String host = ClientUtils.pickOneHost(druidNodeDiscovery); + + if (host == null) { + throw DruidException.forPersona(DruidException.Persona.ADMIN) + .ofCategory(DruidException.Category.NOT_FOUND) + .build("A leader node could not be found for [%s] service. Check the logs to validate that service is healthy.", NodeRole.BROKER); + } + return new Request(httpMethod, new URL(StringUtils.format("%s%s", host, urlPath))); + } + + public String sendQuery(final Request request) throws Exception + { + return RetryUtils.retry( + () -> { + Request newRequestUrl = getNewRequestUrl(request); + final StringFullResponseHolder fullResponseHolder = brokerHttpClient.go(newRequestUrl, new StringFullResponseHandler(StandardCharsets.UTF_8)).get(); + + HttpResponseStatus responseStatus = fullResponseHolder.getResponse().getStatus(); + if (HttpResponseStatus.SERVICE_UNAVAILABLE.equals(responseStatus) + || HttpResponseStatus.GATEWAY_TIMEOUT.equals(responseStatus)) { + // 503 and 504 are transient, so surface them as an IOE: the retry predicate below + // retries IOE, whereas a DruidException would abort on the first attempt. + throw new IOE("Request to broker failed due to failed response status: [%s]", responseStatus); + } else if (responseStatus.getCode() != HttpServletResponse.SC_OK) { + throw DruidException.forPersona(DruidException.Persona.OPERATOR) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + 
.build("Request to broker failed due to failed response code: [%s]", responseStatus.getCode()); + } + return fullResponseHolder.getContent(); + }, + (throwable) -> { + if (throwable instanceof ExecutionException) { + return throwable.getCause() instanceof IOException || throwable.getCause() instanceof ChannelException; + } + return throwable instanceof IOE; + }, + MAX_RETRIES + ); + } + + private Request getNewRequestUrl(Request oldRequest) + { + try { + return ClientUtils.withUrl( + oldRequest, + new URL(StringUtils.format("%s%s", ClientUtils.pickOneHost(druidNodeDiscovery), oldRequest.getUrl().getPath())) + ); + } + catch (MalformedURLException e) { + // Not an IOException; this is our own fault. + throw DruidException.defensive( + "Failed to build url with path[%s] and query string [%s].", + oldRequest.getUrl().getPath(), + oldRequest.getUrl().getQuery() + ); + } + } +} diff --git a/server/src/main/java/org/apache/druid/discovery/ClientUtils.java b/server/src/main/java/org/apache/druid/discovery/ClientUtils.java new file mode 100644 index 00000000000..b9c53343c0d --- /dev/null +++ b/server/src/main/java/org/apache/druid/discovery/ClientUtils.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.discovery; + +import com.google.common.collect.Lists; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.http.client.Request; + +import javax.annotation.Nullable; +import java.net.URL; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; + +/** + * Utils class for shared client methods + */ +public class ClientUtils +{ + @Nullable + public static String pickOneHost(DruidNodeDiscovery druidNodeDiscovery) + { + Iterator<DiscoveryDruidNode> iter = druidNodeDiscovery.getAllNodes().iterator(); + List<DiscoveryDruidNode> discoveryDruidNodeList = Lists.newArrayList(iter); + if (!discoveryDruidNodeList.isEmpty()) { + DiscoveryDruidNode node = discoveryDruidNodeList.get(ThreadLocalRandom.current().nextInt(discoveryDruidNodeList.size())); + return StringUtils.format( + "%s://%s", + node.getDruidNode().getServiceScheme(), + node.getDruidNode().getHostAndPortToUse() + ); + } + return null; + } + + public static Request withUrl(Request old, URL url) + { + Request req = new Request(old.getMethod(), url); + req.addHeaderValues(old.getHeaders()); + if (old.hasContent()) { + req.setContent(old.getContent().copy()); + } + return req; + } +} diff --git a/server/src/main/java/org/apache/druid/discovery/DruidLeaderClient.java b/server/src/main/java/org/apache/druid/discovery/DruidLeaderClient.java index d681004326c..4ca1441f6f2 100644 --- a/server/src/main/java/org/apache/druid/discovery/DruidLeaderClient.java +++ b/server/src/main/java/org/apache/druid/discovery/DruidLeaderClient.java @@ -39,12 +39,10 @@ import org.jboss.netty.channel.ChannelException; import org.jboss.netty.handler.codec.http.HttpMethod; import org.jboss.netty.handler.codec.http.HttpResponseStatus; -import javax.annotation.Nullable; import java.io.IOException; import java.net.MalformedURLException; import java.net.URL; import java.nio.charset.StandardCharsets; -import java.util.Iterator; import java.util.concurrent.ExecutionException; 
import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -202,7 +200,7 @@ public class DruidLeaderClient redirectUrl.getPort() )); - request = withUrl(request, redirectUrl); + request = ClientUtils.withUrl(request, redirectUrl); } else if (HttpResponseStatus.SERVICE_UNAVAILABLE.equals(responseStatus) || HttpResponseStatus.GATEWAY_TIMEOUT.equals(responseStatus)) { log.warn( @@ -260,7 +258,7 @@ public class DruidLeaderClient { final String leader = currentKnownLeader.accumulateAndGet( null, - (current, given) -> current == null || !cached ? pickOneHost() : current + (current, given) -> current == null || !cached ? ClientUtils.pickOneHost(druidNodeDiscovery) : current ); if (leader == null) { @@ -274,43 +272,17 @@ public class DruidLeaderClient } } - @Nullable - private String pickOneHost() - { - Iterator iter = druidNodeDiscovery.getAllNodes().iterator(); - if (iter.hasNext()) { - DiscoveryDruidNode node = iter.next(); - return StringUtils.format( - "%s://%s", - node.getDruidNode().getServiceScheme(), - node.getDruidNode().getHostAndPortToUse() - ); - } - - return null; - } - - private Request withUrl(Request old, URL url) - { - Request req = new Request(old.getMethod(), url); - req.addHeaderValues(old.getHeaders()); - if (old.hasContent()) { - req.setContent(old.getContent()); - } - return req; - } - private Request getNewRequestUrlInvalidatingCache(Request oldRequest) throws IOException { try { Request newRequest; if (oldRequest.getUrl().getQuery() == null) { - newRequest = withUrl( + newRequest = ClientUtils.withUrl( oldRequest, new URL(StringUtils.format("%s%s", getCurrentKnownLeader(false), oldRequest.getUrl().getPath())) ); } else { - newRequest = withUrl( + newRequest = ClientUtils.withUrl( oldRequest, new URL(StringUtils.format( "%s%s?%s", diff --git a/server/src/test/java/org/apache/druid/discovery/BrokerClientTest.java b/server/src/test/java/org/apache/druid/discovery/BrokerClientTest.java new file mode 100644 index 
00000000000..333882a4305 --- /dev/null +++ b/server/src/test/java/org/apache/druid/discovery/BrokerClientTest.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.discovery; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.inject.Injector; +import com.google.inject.Key; +import com.google.inject.name.Names; +import org.apache.druid.guice.GuiceInjectors; +import org.apache.druid.guice.Jerseys; +import org.apache.druid.guice.JsonConfigProvider; +import org.apache.druid.guice.LazySingleton; +import org.apache.druid.guice.LifecycleModule; +import org.apache.druid.guice.annotations.Self; +import org.apache.druid.initialization.Initialization; +import org.apache.druid.java.util.http.client.HttpClient; +import org.apache.druid.java.util.http.client.Request; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.initialization.BaseJettyTest; +import org.apache.druid.server.initialization.jetty.JettyServerInitializer; +import org.easymock.EasyMock; +import org.eclipse.jetty.server.Server; +import org.jboss.netty.handler.codec.http.HttpMethod; +import org.junit.Assert; +import 
org.junit.Test; + +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import java.nio.charset.StandardCharsets; + +public class BrokerClientTest extends BaseJettyTest +{ + private DiscoveryDruidNode discoveryDruidNode; + private HttpClient httpClient; + + @Override + protected Injector setupInjector() + { + final DruidNode node = new DruidNode("test", "localhost", false, null, null, true, false); + discoveryDruidNode = new DiscoveryDruidNode(node, NodeRole.BROKER, ImmutableMap.of()); + + Injector injector = Initialization.makeInjectorWithModules( + GuiceInjectors.makeStartupInjector(), ImmutableList.of( + binder -> { + JsonConfigProvider.bindInstance( + binder, + Key.get(DruidNode.class, Self.class), + node + ); + binder.bind(Integer.class).annotatedWith(Names.named("port")).toInstance(node.getPlaintextPort()); + binder.bind(JettyServerInitializer.class).to(DruidLeaderClientTest.TestJettyServerInitializer.class).in(LazySingleton.class); + Jerseys.addResource(binder, SimpleResource.class); + LifecycleModule.register(binder, Server.class); + } + ) + ); + httpClient = injector.getInstance(BaseJettyTest.ClientHolder.class).getClient(); + return injector; + } + + @Test + public void testSimple() throws Exception + { + DruidNodeDiscovery druidNodeDiscovery = EasyMock.createMock(DruidNodeDiscovery.class); + EasyMock.expect(druidNodeDiscovery.getAllNodes()).andReturn(ImmutableList.of(discoveryDruidNode)).anyTimes(); + + DruidNodeDiscoveryProvider druidNodeDiscoveryProvider = EasyMock.createMock(DruidNodeDiscoveryProvider.class); + EasyMock.expect(druidNodeDiscoveryProvider.getForNodeRole(NodeRole.BROKER)).andReturn(druidNodeDiscovery); + + EasyMock.replay(druidNodeDiscovery, druidNodeDiscoveryProvider); + + BrokerClient brokerClient = new BrokerClient( + httpClient, + druidNodeDiscoveryProvider + ); + + Request request = brokerClient.makeRequest(HttpMethod.POST, 
"/simple/direct"); + request.setContent("hello".getBytes(StandardCharsets.UTF_8)); + Assert.assertEquals("hello", brokerClient.sendQuery(request)); + } + + @Test + public void testError() throws Exception + { + DruidNodeDiscovery druidNodeDiscovery = EasyMock.createMock(DruidNodeDiscovery.class); + EasyMock.expect(druidNodeDiscovery.getAllNodes()).andReturn(ImmutableList.of(discoveryDruidNode)).anyTimes(); + + DruidNodeDiscoveryProvider druidNodeDiscoveryProvider = EasyMock.createMock(DruidNodeDiscoveryProvider.class); + EasyMock.expect(druidNodeDiscoveryProvider.getForNodeRole(NodeRole.BROKER)).andReturn(druidNodeDiscovery); + + EasyMock.replay(druidNodeDiscovery, druidNodeDiscoveryProvider); + + BrokerClient brokerClient = new BrokerClient( + httpClient, + druidNodeDiscoveryProvider + ); + + Request request = brokerClient.makeRequest(HttpMethod.POST, "/simple/flakey"); + request.setContent("hello".getBytes(StandardCharsets.UTF_8)); + Assert.assertEquals("hello", brokerClient.sendQuery(request)); + } + + @Path("/simple") + public static class SimpleResource + { + private static int attempt = 0; + + @POST + @Path("/direct") + @Produces(MediaType.APPLICATION_JSON) + public Response direct(String input) + { + if ("hello".equals(input)) { + return Response.ok("hello").build(); + } else { + return Response.serverError().build(); + } + } + + @POST + @Path("/flakey") + @Produces(MediaType.APPLICATION_JSON) + public Response redirecting() + { + if (attempt > 2) { + return Response.ok("hello").build(); + } else { + attempt += 1; + return Response.status(504).build(); + } + } + } +} diff --git a/server/src/test/java/org/apache/druid/discovery/DruidLeaderClientTest.java b/server/src/test/java/org/apache/druid/discovery/DruidLeaderClientTest.java index f0f91469dc8..4d9870c6e6e 100644 --- a/server/src/test/java/org/apache/druid/discovery/DruidLeaderClientTest.java +++ b/server/src/test/java/org/apache/druid/discovery/DruidLeaderClientTest.java @@ -281,7 +281,7 @@ public class 
DruidLeaderClientTest extends BaseJettyTest Assert.assertEquals("http://localhost:1234/", druidLeaderClient.findCurrentLeader()); } - private static class TestJettyServerInitializer implements JettyServerInitializer + static class TestJettyServerInitializer implements JettyServerInitializer { @Override public void initialize(Server server, Injector injector) diff --git a/website/.spelling b/website/.spelling index cc4e02fcf29..d08c87c12bd 100644 --- a/website/.spelling +++ b/website/.spelling @@ -440,6 +440,7 @@ preemptible prefetch prefetched prefetching +precached prepend prepended prepending @@ -747,6 +748,7 @@ TooManyWorkers NotEnoughMemory WorkerFailed WorkerRpcFailed +TIMED_OUT # MSQ context parameters maxNumTasks taskAssignment