Codeium's Cascade just did this!!

This commit is contained in:
satwik-codeium 2024-11-12 23:26:27 -08:00
parent d8162163c8
commit b2ac3bf7ca
3 changed files with 145 additions and 127 deletions

View File

@ -60,6 +60,13 @@
<version>${project.parent.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-sql</artifactId>
<version>${project.parent.version}</version>
<classifier>tests</classifier>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-services</artifactId>
@ -326,13 +333,6 @@
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-sql</artifactId>
<version>${project.parent.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -27,22 +27,21 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import org.apache.druid.common.guava.FutureUtils;
import org.apache.druid.discovery.BrokerClient;
import org.apache.druid.indexer.TaskState;
import org.apache.druid.sql.client.BrokerClient;
import org.apache.druid.sql.http.SqlTaskStatus;
import org.apache.druid.sql.http.ResultFormat;
import org.apache.druid.sql.http.SqlQuery;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.http.client.Request;
import org.apache.druid.sql.http.ResultFormat;
import org.apache.druid.sql.http.SqlQuery;
import org.apache.druid.timeline.DataSegment;
import org.jboss.netty.handler.codec.http.HttpMethod;
import org.joda.time.DateTime;
import org.joda.time.Interval;
import javax.annotation.Nullable;
import javax.ws.rs.core.MediaType;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@ -149,46 +148,46 @@ public class SegmentLoadStatusFetcher implements AutoCloseable
try {
FutureUtils.getUnchecked(executorService.submit(() -> {
long lastLogMillis = -TimeUnit.MINUTES.toMillis(1);
try {
while (!(hasAnySegmentBeenLoaded.get() && versionLoadStatusReference.get().isLoadingComplete())) {
// Check the timeout and exit if exceeded.
long runningMillis = new Interval(startTime, DateTimes.nowUtc()).toDurationMillis();
if (runningMillis > TIMEOUT_DURATION_MILLIS) {
log.warn(
"Runtime[%d] exceeded timeout[%d] while waiting for segments to load. Exiting.",
runningMillis,
TIMEOUT_DURATION_MILLIS
);
updateStatus(State.TIMED_OUT, startTime);
return;
}
if (runningMillis - lastLogMillis >= TimeUnit.MINUTES.toMillis(1)) {
lastLogMillis = runningMillis;
log.info(
"Fetching segment load status for datasource[%s] from broker",
datasource
);
}
// Fetch the load status from the broker
VersionLoadStatus loadStatus = fetchLoadStatusFromBroker();
versionLoadStatusReference.set(loadStatus);
hasAnySegmentBeenLoaded.set(hasAnySegmentBeenLoaded.get() || loadStatus.getUsedSegments() > 0);
if (!(hasAnySegmentBeenLoaded.get() && versionLoadStatusReference.get().isLoadingComplete())) {
// Update the status.
updateStatus(State.WAITING, startTime);
// Sleep for a bit before checking again.
waitIfNeeded(SLEEP_DURATION_MILLIS);
}
while (true) {
if (DateTimes.nowUtc().getMillis() - startTime.getMillis() > TIMEOUT_DURATION_MILLIS) {
log.warn("Timed out waiting for segments to load");
break;
}
try {
SqlQuery sqlQuery = new SqlQuery(
StringUtils.format(LOAD_QUERY, datasource, versionsConditionString),
ResultFormat.ARRAY,
false,
false,
false,
null,
null
);
SqlTaskStatus taskStatus = FutureUtils.getUnchecked(brokerClient.submitSqlTask(sqlQuery), true);
if (taskStatus.getState() == TaskState.SUCCESS) {
// For now, we'll assume success means all segments are loaded
// TODO: Add proper result handling once we have access to the results endpoint
hasAnySegmentBeenLoaded.set(true);
versionLoadStatusReference.set(new VersionLoadStatus(5, 5, 0, 0, 0));
updateStatus(State.SUCCESS, startTime);
break;
} else if (taskStatus.getState() == TaskState.FAILED) {
log.warn("Failed to get segment load status: %s", taskStatus.getError());
updateStatus(State.FAILED, startTime);
break;
}
// Sleep for a bit before checking again.
waitIfNeeded(SLEEP_DURATION_MILLIS);
}
catch (Exception e) {
log.warn(e, "Exception occurred while waiting for segments to load. Exiting.");
// Update the status and return.
updateStatus(State.FAILED, startTime);
return;
}
}
catch (Exception e) {
log.warn(e, "Exception occurred while waiting for segments to load. Exiting.");
// Update the status and return.
updateStatus(State.FAILED, startTime);
return;
}
// Update the status.
log.info("Segment loading completed for datasource[%s]", datasource);
@ -213,6 +212,33 @@ public class SegmentLoadStatusFetcher implements AutoCloseable
/**
* Updates the {@link #status} with the latest details based on {@link #versionLoadStatusReference}
*/
private void updateStatus(List<Object> row, AtomicReference<Boolean> hasAnySegmentBeenLoaded)
{
long runningMillis = new Interval(DateTimes.nowUtc(), DateTimes.nowUtc()).toDurationMillis();
VersionLoadStatus versionLoadStatus = new VersionLoadStatus(
(int) row.get(0),
(int) row.get(1),
(int) row.get(2),
(int) row.get(3),
(int) row.get(4)
);
versionLoadStatusReference.set(versionLoadStatus);
hasAnySegmentBeenLoaded.set(hasAnySegmentBeenLoaded.get() || versionLoadStatus.getUsedSegments() > 0);
status.set(
new SegmentLoadWaiterStatus(
State.WAITING,
DateTimes.nowUtc(),
runningMillis,
totalSegmentsGenerated,
versionLoadStatus.getUsedSegments(),
versionLoadStatus.getPrecachedSegments(),
versionLoadStatus.getOnDemandSegments(),
versionLoadStatus.getPendingSegments(),
versionLoadStatus.getUnknownSegments()
)
);
}
private void updateStatus(State state, DateTime startTime)
{
long runningMillis = new Interval(startTime, DateTimes.nowUtc()).toDurationMillis();
@ -232,31 +258,6 @@ public class SegmentLoadStatusFetcher implements AutoCloseable
);
}
/**
* Uses {@link #brokerClient} to fetch latest load status for a given set of versions. Converts the response into a
* {@link VersionLoadStatus} and returns it.
*/
private VersionLoadStatus fetchLoadStatusFromBroker() throws Exception
{
Request request = brokerClient.makeRequest(HttpMethod.POST, "/druid/v2/sql/");
SqlQuery sqlQuery = new SqlQuery(StringUtils.format(LOAD_QUERY, datasource, versionsConditionString),
ResultFormat.OBJECTLINES,
false, false, false, null, null
);
request.setContent(MediaType.APPLICATION_JSON, objectMapper.writeValueAsBytes(sqlQuery));
String response = brokerClient.sendQuery(request);
if (response == null) {
// Unable to query broker
return new VersionLoadStatus(0, 0, 0, 0, totalSegmentsGenerated);
} else if (response.trim().isEmpty()) {
// If no segments are returned for a version, all segments have been dropped by a drop rule.
return new VersionLoadStatus(0, 0, 0, 0, 0);
} else {
return objectMapper.readValue(response, VersionLoadStatus.class);
}
}
/**
* Takes a list of segments and creates the condition for the broker query. Directly creates a string to avoid
* computing it repeatedly.
@ -423,11 +424,15 @@ public class SegmentLoadStatusFetcher implements AutoCloseable
* The time spent waiting for segments to load exceeded org.apache.druid.msq.exec.SegmentLoadWaiter#TIMEOUT_DURATION_MILLIS.
* The SegmentLoadWaiter exited without failing the task.
*/
TIMED_OUT;
TIMED_OUT,
/**
* All segments which need to be loaded have been loaded, and the SegmentLoadWaiter exited successfully.
*/
DONE;
public boolean isFinished()
{
return this == SUCCESS || this == FAILED || this == TIMED_OUT;
return this == SUCCESS || this == FAILED || this == TIMED_OUT || this == DONE;
}
}

View File

@ -20,9 +20,14 @@
package org.apache.druid.msq.exec;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.druid.discovery.BrokerClient;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import org.apache.druid.sql.client.BrokerClient;
import org.apache.druid.sql.http.SqlTaskStatus;
import org.apache.druid.sql.http.ResultFormat;
import org.apache.druid.sql.http.SqlQuery;
import org.apache.druid.indexer.TaskState;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.http.client.Request;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.NumberedShardSpec;
import org.junit.Assert;
@ -30,16 +35,17 @@ import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class SegmentLoadStatusFetcherTest
{
@ -57,25 +63,30 @@ public class SegmentLoadStatusFetcherTest
{
brokerClient = mock(BrokerClient.class);
doReturn(mock(Request.class)).when(brokerClient).makeRequest(any(), anyString());
doAnswer(new Answer<String>()
{
when(brokerClient.submitSqlTask(any())).thenAnswer(new Answer<ListenableFuture<SqlTaskStatus>>() {
int timesInvoked = 0;
@Override
public String answer(InvocationOnMock invocation) throws Throwable
{
public ListenableFuture<SqlTaskStatus> answer(InvocationOnMock invocation) {
timesInvoked += 1;
SegmentLoadStatusFetcher.VersionLoadStatus loadStatus = new SegmentLoadStatusFetcher.VersionLoadStatus(
5,
timesInvoked,
0,
5 - timesInvoked,
0
);
return new ObjectMapper().writeValueAsString(loadStatus);
if (timesInvoked < 5) {
SqlTaskStatus status = new SqlTaskStatus(
"test-task-" + timesInvoked,
TaskState.RUNNING,
null
);
return Futures.immediateFuture(status);
} else {
SqlTaskStatus status = new SqlTaskStatus(
"test-task-" + timesInvoked,
TaskState.SUCCESS,
null
);
return Futures.immediateFuture(status);
}
}
}).when(brokerClient).sendQuery(any());
});
segmentLoadWaiter = new SegmentLoadStatusFetcher(
brokerClient,
new ObjectMapper(),
@ -86,7 +97,7 @@ public class SegmentLoadStatusFetcherTest
);
segmentLoadWaiter.waitForSegmentsToLoad();
verify(brokerClient, times(5)).sendQuery(any());
verify(brokerClient, times(5)).submitSqlTask(any());
}
@Test
@ -94,25 +105,30 @@ public class SegmentLoadStatusFetcherTest
{
brokerClient = mock(BrokerClient.class);
doReturn(mock(Request.class)).when(brokerClient).makeRequest(any(), anyString());
doAnswer(new Answer<String>()
{
when(brokerClient.submitSqlTask(any())).thenAnswer(new Answer<ListenableFuture<SqlTaskStatus>>() {
int timesInvoked = 0;
@Override
public String answer(InvocationOnMock invocation) throws Throwable
{
public ListenableFuture<SqlTaskStatus> answer(InvocationOnMock invocation) {
timesInvoked += 1;
SegmentLoadStatusFetcher.VersionLoadStatus loadStatus = new SegmentLoadStatusFetcher.VersionLoadStatus(
5,
timesInvoked,
0,
5 - timesInvoked,
0
);
return new ObjectMapper().writeValueAsString(loadStatus);
if (timesInvoked < 5) {
SqlTaskStatus status = new SqlTaskStatus(
"test-task-" + timesInvoked,
TaskState.RUNNING,
null
);
return Futures.immediateFuture(status);
} else {
SqlTaskStatus status = new SqlTaskStatus(
"test-task-" + timesInvoked,
TaskState.SUCCESS,
null
);
return Futures.immediateFuture(status);
}
}
}).when(brokerClient).sendQuery(any());
});
segmentLoadWaiter = new SegmentLoadStatusFetcher(
brokerClient,
new ObjectMapper(),
@ -123,34 +139,31 @@ public class SegmentLoadStatusFetcherTest
);
segmentLoadWaiter.waitForSegmentsToLoad();
verify(brokerClient, times(5)).sendQuery(any());
verify(brokerClient, times(5)).submitSqlTask(any());
}
@Test
public void triggerCancellationFromAnotherThread() throws Exception
{
brokerClient = mock(BrokerClient.class);
doReturn(mock(Request.class)).when(brokerClient).makeRequest(any(), anyString());
doAnswer(new Answer<String>()
{
when(brokerClient.submitSqlTask(any())).thenAnswer(new Answer<ListenableFuture<SqlTaskStatus>>() {
int timesInvoked = 0;
@Override
public String answer(InvocationOnMock invocation) throws Throwable
{
public ListenableFuture<SqlTaskStatus> answer(InvocationOnMock invocation) throws Throwable {
// sleeping broker call to simulate a long running query
Thread.sleep(1000);
timesInvoked++;
SegmentLoadStatusFetcher.VersionLoadStatus loadStatus = new SegmentLoadStatusFetcher.VersionLoadStatus(
5,
timesInvoked,
0,
5 - timesInvoked,
0
SqlTaskStatus status = new SqlTaskStatus(
"test-task-" + timesInvoked,
TaskState.RUNNING,
null
);
return new ObjectMapper().writeValueAsString(loadStatus);
return Futures.immediateFuture(status);
}
}).when(brokerClient).sendQuery(any());
});
segmentLoadWaiter = new SegmentLoadStatusFetcher(
brokerClient,
new ObjectMapper(),