diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index ab9a6382dcc..11c71c8271a 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -217,6 +217,12 @@ ${project.parent.version} runtime + + org.apache.druid.extensions + druid-testing-tools + ${project.parent.version} + runtime + org.apache.druid.extensions simple-client-sslcontext diff --git a/integration-tests/src/main/java/org/apache/druid/testing/clients/AbstractQueryResourceTestClient.java b/integration-tests/src/main/java/org/apache/druid/testing/clients/AbstractQueryResourceTestClient.java index ce2703cc9f4..907ac8d6b4b 100644 --- a/integration-tests/src/main/java/org/apache/druid/testing/clients/AbstractQueryResourceTestClient.java +++ b/integration-tests/src/main/java/org/apache/druid/testing/clients/AbstractQueryResourceTestClient.java @@ -46,6 +46,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; public abstract class AbstractQueryResourceTestClient { @@ -132,8 +133,6 @@ public abstract class AbstractQueryResourceTestClient this.acceptHeader = acceptHeader; } - public abstract String getBrokerURL(); - public List> query(String url, QueryType query) { try { @@ -154,7 +153,7 @@ public abstract class AbstractQueryResourceTestClient if (!response.getStatus().equals(HttpResponseStatus.OK)) { throw new ISE( "Error while querying[%s] status[%s] content[%s]", - getBrokerURL(), + url, response.getStatus(), new String(response.getContent(), StandardCharsets.UTF_8) ); @@ -190,4 +189,20 @@ public abstract class AbstractQueryResourceTestClient throw new RuntimeException(e); } } + + public HttpResponseStatus cancelQuery(String url, long timeoutMs) + { + try { + Request request = new Request(HttpMethod.DELETE, new URL(url)); + Future future = httpClient.go( + request, + StatusResponseHandler.getInstance() + ); + StatusResponseHolder responseHolder = future.get(timeoutMs, TimeUnit.MILLISECONDS); + return responseHolder.getStatus(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } } diff --git a/integration-tests/src/main/java/org/apache/druid/testing/clients/QueryResourceTestClient.java b/integration-tests/src/main/java/org/apache/druid/testing/clients/QueryResourceTestClient.java index 33820fe6d70..b1045d774b3 100644 --- a/integration-tests/src/main/java/org/apache/druid/testing/clients/QueryResourceTestClient.java +++ b/integration-tests/src/main/java/org/apache/druid/testing/clients/QueryResourceTestClient.java @@ -23,7 +23,6 @@ package org.apache.druid.testing.clients; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.inject.Inject; import org.apache.druid.guice.annotations.Smile; -import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.http.client.HttpClient; import org.apache.druid.query.Query; import org.apache.druid.testing.IntegrationTestingConfig; @@ -58,15 +57,6 @@ public class QueryResourceTestClient extends AbstractQueryResourceTestClient coordinatorClient.areSegmentsLoaded(WIKIPEDIA_DATA_SOURCE), "wikipedia segment load" + ); + } + + @Test + public void testCancelValidQuery() throws Exception + { + final String queryId = "sql-cancel-test"; + final List> queryResponseFutures = new ArrayList<>(); + for (int i = 0; i < NUM_QUERIES; i++) { + queryResponseFutures.add( + sqlClient.queryAsync( + sqlHelper.getQueryURL(config.getRouterUrl()), + new SqlQuery(QUERY, null, false, ImmutableMap.of("sqlQueryId", queryId), null) + ) + ); + } + + // Wait until the sqlLifecycle is authorized and registered + Thread.sleep(1000); + final HttpResponseStatus responseStatus = sqlClient.cancelQuery( + sqlHelper.getCancelUrl(config.getRouterUrl(), queryId), + 1000 + ); + if (!responseStatus.equals(HttpResponseStatus.ACCEPTED)) { + throw new RE("Failed to cancel query [%s]", queryId); + } + + for (Future queryResponseFuture : queryResponseFutures) { + final StatusResponseHolder queryResponse = queryResponseFuture.get(1, TimeUnit.SECONDS); + if (!queryResponse.getStatus().equals(HttpResponseStatus.INTERNAL_SERVER_ERROR)) { + throw new ISE("Query is not canceled after cancel request"); + } + QueryException queryException = jsonMapper.readValue(queryResponse.getContent(), QueryException.class); + if (!QueryInterruptedException.QUERY_CANCELLED.equals(queryException.getErrorCode())) { + throw new ISE( + "Expected error code [%s], actual [%s]", + QueryInterruptedException.QUERY_CANCELLED, + queryException.getErrorCode() + ); + } + } + } + + @Test + public void testCancelInvalidQuery() throws Exception + { + final Future queryResponseFuture = sqlClient + .queryAsync( + sqlHelper.getQueryURL(config.getRouterUrl()), + new SqlQuery(QUERY, null, false, ImmutableMap.of("sqlQueryId", "validId"), null) + ); + + // Wait until the sqlLifecycle is authorized and registered + Thread.sleep(1000); + final HttpResponseStatus responseStatus = sqlClient.cancelQuery( + sqlHelper.getCancelUrl(config.getRouterUrl(), "invalidId"), + 1000 + ); + if (!responseStatus.equals(HttpResponseStatus.NOT_FOUND)) { + throw new RE("Expected http response [%s], actual response [%s]", HttpResponseStatus.NOT_FOUND, responseStatus); + } + + final StatusResponseHolder queryResponse = queryResponseFuture.get(30, TimeUnit.SECONDS); + if (!queryResponse.getStatus().equals(HttpResponseStatus.OK)) { + throw new ISE("Query is not canceled after cancel request"); + } + } +} diff --git a/server/src/main/java/org/apache/druid/server/QueryScheduler.java b/server/src/main/java/org/apache/druid/server/QueryScheduler.java index 46f2580ba54..d959e48ea4f 100644 --- a/server/src/main/java/org/apache/druid/server/QueryScheduler.java +++ b/server/src/main/java/org/apache/druid/server/QueryScheduler.java @@ -64,12 +64,26 @@ public class QueryScheduler implements QueryWatcher private final QueryPrioritizationStrategy prioritizationStrategy; private final QueryLaningStrategy laningStrategy; private final BulkheadRegistry laneRegistry; + /** - * mapping of query id to set of futures associated with the query + * mapping of query id to set of futures associated with the query. + * This map is synchronized as there are 2 threads, query execution thread and query canceling thread, + * that can access the map at the same time. + * + * The updates (additions and removals) on this and {@link #queryDatasources} are racy + * as those updates are not being done atomically on those 2 maps, + * but it is OK in most cases since they will be cleaned up once the query is done. */ private final SetMultimap> queryFutures; + /** - * mapping of query id to set of datasource names that are being queried, used for authorization + * mapping of query id to set of datasource names that are being queried, used for authorization. + * This map is synchronized as there are 2 threads, query execution thread and query canceling thread, + * that can access the map at the same time. + * + * The updates (additions and removals) on this and {@link #queryFutures} are racy + * as those updates are not being done atomically on those 2 maps, + * but it is OK in most cases since they will be cleaned up once the query is done. */ private final SetMultimap queryDatasources; diff --git a/services/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java b/services/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java index 0a2da6293a9..6779c06f596 100644 --- a/services/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java +++ b/services/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java @@ -228,7 +228,7 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu targetServer = hostFinder.findServerAvatica(connectionId); byte[] requestBytes = objectMapper.writeValueAsBytes(requestMap); request.setAttribute(AVATICA_QUERY_ATTRIBUTE, requestBytes); - } else if (isNativeQueryEndpoint && HttpMethod.DELETE.is(method)) { + } else if (HttpMethod.DELETE.is(method)) { // query cancellation request targetServer = hostFinder.pickDefaultServer(); broadcastQueryCancelRequest(request, targetServer); @@ -285,8 +285,6 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu */ private void broadcastQueryCancelRequest(HttpServletRequest request, Server targetServer) { - // send query cancellation to all brokers this query may have gone to - // to keep the code simple, the proxy servlet will also send a request to the default targetServer. for (final Server server : hostFinder.getAllServers()) { if (server.getHost().equals(targetServer.getHost())) { continue; diff --git a/sql/src/main/java/org/apache/druid/sql/SqlLifecycle.java b/sql/src/main/java/org/apache/druid/sql/SqlLifecycle.java index c30aba0fa23..98a7c22abb7 100644 --- a/sql/src/main/java/org/apache/druid/sql/SqlLifecycle.java +++ b/sql/src/main/java/org/apache/druid/sql/SqlLifecycle.java @@ -24,7 +24,6 @@ import com.google.common.base.Preconditions; import com.google.common.collect.Iterables; import com.google.errorprone.annotations.concurrent.GuardedBy; import org.apache.calcite.avatica.remote.TypedValue; -import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.tools.RelConversionException; import org.apache.calcite.tools.ValidationException; @@ -40,6 +39,7 @@ import org.apache.druid.java.util.emitter.service.ServiceMetricEvent; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryTimeoutException; +import org.apache.druid.server.QueryScheduler; import org.apache.druid.server.QueryStats; import org.apache.druid.server.RequestLogLine; import org.apache.druid.server.log.RequestLogger; @@ -64,7 +64,9 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.UUID; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -77,11 +79,10 @@ import java.util.stream.Collectors; *
  • Validation and Authorization ({@link #validateAndAuthorize(HttpServletRequest)} or {@link #validateAndAuthorize(AuthenticationResult)})
  • *
  • Planning ({@link #plan()})
  • *
  • Execution ({@link #execute()})
  • - *
  • Logging ({@link #emitLogsAndMetrics(Throwable, String, long)})
  • + *
  • Logging ({@link #finalizeStateAndEmitLogsAndMetrics(Throwable, String, long)})
  • * * - *

    Unlike QueryLifecycle, this class is designed to be thread safe so that it can be used in multi-threaded - * scenario (JDBC) without external synchronization. + * Every method in this class must be called by the same thread except for {@link #cancel()}. */ public class SqlLifecycle { @@ -90,34 +91,33 @@ public class SqlLifecycle private final PlannerFactory plannerFactory; private final ServiceEmitter emitter; private final RequestLogger requestLogger; + private final QueryScheduler queryScheduler; private final long startMs; private final long startNs; - private final Object lock = new Object(); - @GuardedBy("lock") + /** + * This lock coordinates the access to {@link #state} as there is a happens-before relationship + * between {@link #cancel} and {@link #transition}. + */ + private final Object stateLock = new Object(); + @GuardedBy("stateLock") private State state = State.NEW; // init during intialize - @GuardedBy("lock") private String sql; - @GuardedBy("lock") private Map queryContext; - @GuardedBy("lock") private List parameters; // init during plan - @GuardedBy("lock") private PlannerContext plannerContext; - @GuardedBy("lock") private ValidationResult validationResult; - @GuardedBy("lock") private PrepareResult prepareResult; - @GuardedBy("lock") private PlannerResult plannerResult; public SqlLifecycle( PlannerFactory plannerFactory, ServiceEmitter emitter, RequestLogger requestLogger, + QueryScheduler queryScheduler, long startMs, long startNs ) @@ -125,6 +125,7 @@ public class SqlLifecycle this.plannerFactory = plannerFactory; this.emitter = emitter; this.requestLogger = requestLogger; + this.queryScheduler = queryScheduler; this.startMs = startMs; this.startNs = startNs; this.parameters = Collections.emptyList(); @@ -137,15 +138,12 @@ public class SqlLifecycle */ public String initialize(String sql, Map queryContext) { - synchronized (lock) { - transition(State.NEW, State.INITIALIZED); - this.sql = sql; - this.queryContext = contextWithSqlId(queryContext); - return sqlQueryId(); - } + transition(State.NEW, State.INITIALIZED); + this.sql = sql; + this.queryContext = contextWithSqlId(queryContext); + return sqlQueryId(); } - @GuardedBy("lock") private Map contextWithSqlId(Map queryContext) { Map newContext = new HashMap<>(); @@ -161,7 +159,6 @@ public class SqlLifecycle return newContext; } - @GuardedBy("lock") private String sqlQueryId() { return (String) this.queryContext.get(PlannerContext.CTX_SQL_QUERY_ID); @@ -173,11 +170,9 @@ public class SqlLifecycle */ public void setParameters(List parameters) { - synchronized (lock) { - this.parameters = parameters; - if (this.plannerContext != null) { - this.plannerContext.setParameters(parameters); - } + this.parameters = parameters; + if (this.plannerContext != null) { + this.plannerContext.setParameters(parameters); } } @@ -189,21 +184,21 @@ public class SqlLifecycle */ public void validateAndAuthorize(AuthenticationResult authenticationResult) { - synchronized (lock) { + synchronized (stateLock) { if (state == State.AUTHORIZED) { return; } - transition(State.INITIALIZED, State.AUTHORIZING); - validate(authenticationResult); - Access access = doAuthorize( - AuthorizationUtils.authorizeAllResourceActions( - authenticationResult, - Iterables.transform(validationResult.getResources(), AuthorizationUtils.RESOURCE_READ_RA_GENERATOR), - plannerFactory.getAuthorizerMapper() - ) - ); - checkAccess(access); } + transition(State.INITIALIZED, State.AUTHORIZING); + validate(authenticationResult); + Access access = doAuthorize( + AuthorizationUtils.authorizeAllResourceActions( + authenticationResult, + Iterables.transform(validationResult.getResources(), AuthorizationUtils.RESOURCE_READ_RA_GENERATOR), + plannerFactory.getAuthorizerMapper() + ) + ); + checkAccess(access); } /** @@ -215,22 +210,19 @@ public class SqlLifecycle */ public void validateAndAuthorize(HttpServletRequest req) { - synchronized (lock) { - transition(State.INITIALIZED, State.AUTHORIZING); - AuthenticationResult authResult = AuthorizationUtils.authenticationResultFromRequest(req); - validate(authResult); - Access access = doAuthorize( - AuthorizationUtils.authorizeAllResourceActions( - req, - Iterables.transform(validationResult.getResources(), AuthorizationUtils.RESOURCE_READ_RA_GENERATOR), - plannerFactory.getAuthorizerMapper() - ) - ); - checkAccess(access); - } + transition(State.INITIALIZED, State.AUTHORIZING); + AuthenticationResult authResult = AuthorizationUtils.authenticationResultFromRequest(req); + validate(authResult); + Access access = doAuthorize( + AuthorizationUtils.authorizeAllResourceActions( + req, + Iterables.transform(validationResult.getResources(), AuthorizationUtils.RESOURCE_READ_RA_GENERATOR), + plannerFactory.getAuthorizerMapper() + ) + ); + checkAccess(access); } - @GuardedBy("lock") private ValidationResult validate(AuthenticationResult authenticationResult) { try (DruidPlanner planner = plannerFactory.createPlanner(queryContext)) { @@ -251,7 +243,6 @@ public class SqlLifecycle } } - @GuardedBy("lock") private Access doAuthorize(final Access authorizationResult) { if (!authorizationResult.isAllowed()) { @@ -263,7 +254,6 @@ public class SqlLifecycle return authorizationResult; } - @GuardedBy("lock") private void checkAccess(Access access) { plannerContext.setAuthorizationResult(access); @@ -280,22 +270,22 @@ public class SqlLifecycle */ public PrepareResult prepare() throws RelConversionException { - synchronized (lock) { + synchronized (stateLock) { if (state != State.AUTHORIZED) { throw new ISE("Cannot prepare because current state[%s] is not [%s].", state, State.AUTHORIZED); } - Preconditions.checkNotNull(plannerContext, "Cannot prepare, plannerContext is null"); - try (DruidPlanner planner = plannerFactory.createPlannerWithContext(plannerContext)) { - this.prepareResult = planner.prepare(sql); - return prepareResult; - } - // we can't collapse catch clauses since SqlPlanningException has type-sensitive constructors. - catch (SqlParseException e) { - throw new SqlPlanningException(e); - } - catch (ValidationException e) { - throw new SqlPlanningException(e); - } + } + Preconditions.checkNotNull(plannerContext, "Cannot prepare, plannerContext is null"); + try (DruidPlanner planner = plannerFactory.createPlannerWithContext(plannerContext)) { + this.prepareResult = planner.prepare(sql); + return prepareResult; + } + // we can't collapse catch clauses since SqlPlanningException has type-sensitive constructors. + catch (SqlParseException e) { + throw new SqlPlanningException(e); + } + catch (ValidationException e) { + throw new SqlPlanningException(e); } } @@ -304,23 +294,37 @@ public class SqlLifecycle * * If successful, the lifecycle will first transition from {@link State#AUTHORIZED} to {@link State#PLANNED}. */ - public PlannerContext plan() throws RelConversionException + public void plan() throws RelConversionException { - synchronized (lock) { - transition(State.AUTHORIZED, State.PLANNED); - Preconditions.checkNotNull(plannerContext, "Cannot plan, plannerContext is null"); - try (DruidPlanner planner = plannerFactory.createPlannerWithContext(plannerContext)) { - this.plannerResult = planner.plan(sql); - } - // we can't collapse catch clauses since SqlPlanningException has type-sensitive constructors. - catch (SqlParseException e) { - throw new SqlPlanningException(e); - } - catch (ValidationException e) { - throw new SqlPlanningException(e); - } - return plannerContext; + transition(State.AUTHORIZED, State.PLANNED); + Preconditions.checkNotNull(plannerContext, "Cannot plan, plannerContext is null"); + try (DruidPlanner planner = plannerFactory.createPlannerWithContext(plannerContext)) { + this.plannerResult = planner.plan(sql); } + // we can't collapse catch clauses since SqlPlanningException has type-sensitive constructors. + catch (SqlParseException e) { + throw new SqlPlanningException(e); + } + catch (ValidationException e) { + throw new SqlPlanningException(e); + } + } + + /** + * This method must be called after {@link #plan()}. + */ + public SqlRowTransformer createRowTransformer() + { + assert plannerContext != null; + assert plannerResult != null; + + return new SqlRowTransformer(plannerContext.getTimeZone(), plannerResult.rowType()); + } + + @VisibleForTesting + PlannerContext getPlannerContext() + { + return plannerContext; } /** @@ -330,10 +334,8 @@ public class SqlLifecycle */ public Sequence execute() { - synchronized (lock) { - transition(State.PLANNED, State.EXECUTING); - return plannerResult.run(); - } + transition(State.PLANNED, State.EXECUTING); + return plannerResult.run(); } @VisibleForTesting @@ -354,7 +356,9 @@ public class SqlLifecycle result = execute(); } catch (Throwable e) { - emitLogsAndMetrics(e, null, -1); + if (!(e instanceof ForbiddenException)) { + finalizeStateAndEmitLogsAndMetrics(e, null, -1); + } throw e; } @@ -363,7 +367,7 @@ public class SqlLifecycle @Override public void after(boolean isDone, Throwable thrown) { - emitLogsAndMetrics(thrown, null, -1); + finalizeStateAndEmitLogsAndMetrics(thrown, null, -1); } }); } @@ -372,15 +376,34 @@ public class SqlLifecycle @VisibleForTesting public ValidationResult runAnalyzeResources(AuthenticationResult authenticationResult) { - synchronized (lock) { - return validate(authenticationResult); - } + return validate(authenticationResult); } - public RelDataType rowType() + public Set getAuthorizedResources() { - synchronized (lock) { - return plannerResult != null ? plannerResult.rowType() : prepareResult.getRowType(); + assert validationResult != null; + return validationResult.getResources(); + } + + /** + * Cancel all native queries associated to this lifecycle. + * + * This method is thread-safe. + */ + public void cancel() + { + synchronized (stateLock) { + if (state == State.CANCELLED) { + return; + } + state = State.CANCELLED; + } + + final CopyOnWriteArrayList nativeQueryIds = plannerContext.getNativeQueryIds(); + + for (String nativeQueryId : nativeQueryIds) { + log.debug("canceling native query [%s]", nativeQueryId); + queryScheduler.cancelQuery(nativeQueryId); } } @@ -391,104 +414,121 @@ public class SqlLifecycle * @param remoteAddress remote address, for logging; or null if unknown * @param bytesWritten number of bytes written; will become a query/bytes metric if >= 0 */ - public void emitLogsAndMetrics( + public void finalizeStateAndEmitLogsAndMetrics( @Nullable final Throwable e, @Nullable final String remoteAddress, final long bytesWritten ) { - synchronized (lock) { - if (sql == null) { - // Never initialized, don't log or emit anything. - return; + if (sql == null) { + // Never initialized, don't log or emit anything. + return; + } + + synchronized (stateLock) { + assert state != State.UNAUTHORIZED; // should not emit below metrics when the query fails to authorize + + if (state != State.CANCELLED) { + if (state == State.DONE) { + log.warn("Tried to emit logs and metrics twice for query[%s]!", sqlQueryId()); + } + + state = State.DONE; } + } - if (state == State.DONE) { - log.warn("Tried to emit logs and metrics twice for query[%s]!", sqlQueryId()); + final boolean success = e == null; + final long queryTimeNs = System.nanoTime() - startNs; + + try { + ServiceMetricEvent.Builder metricBuilder = ServiceMetricEvent.builder(); + if (plannerContext != null) { + metricBuilder.setDimension("id", plannerContext.getSqlQueryId()); + metricBuilder.setDimension("nativeQueryIds", plannerContext.getNativeQueryIds().toString()); } - - state = State.DONE; - - final boolean success = e == null; - final long queryTimeNs = System.nanoTime() - startNs; - - try { - ServiceMetricEvent.Builder metricBuilder = ServiceMetricEvent.builder(); - if (plannerContext != null) { - metricBuilder.setDimension("id", plannerContext.getSqlQueryId()); - metricBuilder.setDimension("nativeQueryIds", plannerContext.getNativeQueryIds().toString()); - } - if (validationResult != null) { - metricBuilder.setDimension( - "dataSource", - validationResult.getResources().stream().map(Resource::getName).collect(Collectors.toList()).toString() - ); - } - metricBuilder.setDimension("remoteAddress", StringUtils.nullToEmptyNonDruidDataString(remoteAddress)); - metricBuilder.setDimension("success", String.valueOf(success)); - emitter.emit(metricBuilder.build("sqlQuery/time", TimeUnit.NANOSECONDS.toMillis(queryTimeNs))); - if (bytesWritten >= 0) { - emitter.emit(metricBuilder.build("sqlQuery/bytes", bytesWritten)); - } - - final Map statsMap = new LinkedHashMap<>(); - statsMap.put("sqlQuery/time", TimeUnit.NANOSECONDS.toMillis(queryTimeNs)); - statsMap.put("sqlQuery/bytes", bytesWritten); - statsMap.put("success", success); - statsMap.put("context", queryContext); - if (plannerContext != null) { - statsMap.put("identity", plannerContext.getAuthenticationResult().getIdentity()); - queryContext.put("nativeQueryIds", plannerContext.getNativeQueryIds().toString()); - } - if (e != null) { - statsMap.put("exception", e.toString()); - - if (e instanceof QueryInterruptedException || e instanceof QueryTimeoutException) { - statsMap.put("interrupted", true); - statsMap.put("reason", e.toString()); - } - } - - requestLogger.logSqlQuery( - RequestLogLine.forSql( - sql, - queryContext, - DateTimes.utc(startMs), - remoteAddress, - new QueryStats(statsMap) - ) + if (validationResult != null) { + metricBuilder.setDimension( + "dataSource", + validationResult.getResources().stream().map(Resource::getName).collect(Collectors.toList()).toString() ); } - catch (Exception ex) { - log.error(ex, "Unable to log SQL [%s]!", sql); + metricBuilder.setDimension("remoteAddress", StringUtils.nullToEmptyNonDruidDataString(remoteAddress)); + metricBuilder.setDimension("success", String.valueOf(success)); + emitter.emit(metricBuilder.build("sqlQuery/time", TimeUnit.NANOSECONDS.toMillis(queryTimeNs))); + if (bytesWritten >= 0) { + emitter.emit(metricBuilder.build("sqlQuery/bytes", bytesWritten)); } + + final Map statsMap = new LinkedHashMap<>(); + statsMap.put("sqlQuery/time", TimeUnit.NANOSECONDS.toMillis(queryTimeNs)); + statsMap.put("sqlQuery/bytes", bytesWritten); + statsMap.put("success", success); + statsMap.put("context", queryContext); + if (plannerContext != null) { + statsMap.put("identity", plannerContext.getAuthenticationResult().getIdentity()); + queryContext.put("nativeQueryIds", plannerContext.getNativeQueryIds().toString()); + } + if (e != null) { + statsMap.put("exception", e.toString()); + + if (e instanceof QueryInterruptedException || e instanceof QueryTimeoutException) { + statsMap.put("interrupted", true); + statsMap.put("reason", e.toString()); + } + } + + requestLogger.logSqlQuery( + RequestLogLine.forSql( + sql, + queryContext, + DateTimes.utc(startMs), + remoteAddress, + new QueryStats(statsMap) + ) + ); + } + catch (Exception ex) { + log.error(ex, "Unable to log SQL [%s]!", sql); } } @VisibleForTesting public State getState() { - synchronized (lock) { + synchronized (stateLock) { return state; } } @VisibleForTesting - public Map getQueryContext() + Map getQueryContext() { - synchronized (lock) { - return queryContext; - } + return queryContext; } - @GuardedBy("lock") private void transition(final State from, final State to) { - if (state != from) { - throw new ISE("Cannot transition from[%s] to[%s] because current state[%s] is not [%s].", from, to, state, from); - } + synchronized (stateLock) { + if (state == State.CANCELLED) { + throw new QueryInterruptedException( + QueryInterruptedException.QUERY_CANCELLED, + StringUtils.format("Query is canceled [%s]", sqlQueryId()), + null, + null + ); + } + if (state != from) { + throw new ISE( + "Cannot transition from[%s] to[%s] because current state[%s] is not [%s].", + from, + to, + state, + from + ); + } - state = to; + state = to; + } } enum State @@ -499,7 +539,10 @@ public class SqlLifecycle AUTHORIZED, PLANNED, EXECUTING, + + // final states UNAUTHORIZED, - DONE + CANCELLED, // query is cancelled. can be transitioned to this state only after AUTHORIZED. + DONE // query could either succeed or fail } } diff --git a/sql/src/main/java/org/apache/druid/sql/SqlLifecycleFactory.java b/sql/src/main/java/org/apache/druid/sql/SqlLifecycleFactory.java index 250789481b8..948492d64b6 100644 --- a/sql/src/main/java/org/apache/druid/sql/SqlLifecycleFactory.java +++ b/sql/src/main/java/org/apache/druid/sql/SqlLifecycleFactory.java @@ -22,6 +22,7 @@ package org.apache.druid.sql; import com.google.inject.Inject; import org.apache.druid.guice.LazySingleton; import org.apache.druid.java.util.emitter.service.ServiceEmitter; +import org.apache.druid.server.QueryScheduler; import org.apache.druid.server.log.RequestLogger; import org.apache.druid.sql.calcite.planner.PlannerFactory; @@ -31,17 +32,20 @@ public class SqlLifecycleFactory private final PlannerFactory plannerFactory; private final ServiceEmitter emitter; private final RequestLogger requestLogger; + private final QueryScheduler queryScheduler; @Inject public SqlLifecycleFactory( PlannerFactory plannerFactory, ServiceEmitter emitter, - RequestLogger requestLogger + RequestLogger requestLogger, + QueryScheduler queryScheduler ) { this.plannerFactory = plannerFactory; this.emitter = emitter; this.requestLogger = requestLogger; + this.queryScheduler = queryScheduler; } public SqlLifecycle factorize() @@ -50,6 +54,7 @@ public class SqlLifecycleFactory plannerFactory, emitter, requestLogger, + queryScheduler, System.currentTimeMillis(), System.nanoTime() ); diff --git a/sql/src/main/java/org/apache/druid/sql/SqlLifecycleManager.java b/sql/src/main/java/org/apache/druid/sql/SqlLifecycleManager.java new file mode 100644 index 00000000000..8b222eb569a --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/SqlLifecycleManager.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql; + +import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import org.apache.druid.guice.LazySingleton; +import org.apache.druid.sql.SqlLifecycle.State; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * This class manages only _authorized_ {@link SqlLifecycle}s submitted via HTTP, + * such as {@link org.apache.druid.sql.http.SqlResource}. The main use case of this class is + * tracking running queries so that the cancel API can identify the lifecycles to cancel. + * + * This class is thread-safe as there are 2 or more threads that can access lifecycles at the same time + * for query running or query canceling. + * + * For managing and canceling native queries, see {@link org.apache.druid.server.QueryScheduler}. + * As its name indicates, it also performs resource scheduling for native queries based on query lanes + * {@link org.apache.druid.server.QueryLaningStrategy}. + * + * @see org.apache.druid.server.QueryScheduler#cancelQuery(String) + */ +@LazySingleton +public class SqlLifecycleManager +{ + private final Object lock = new Object(); + + @GuardedBy("lock") + private final Map> sqlLifecycles = new HashMap<>(); + + public void add(String sqlQueryId, SqlLifecycle lifecycle) + { + synchronized (lock) { + assert lifecycle.getState() == State.AUTHORIZED; + sqlLifecycles.computeIfAbsent(sqlQueryId, k -> new ArrayList<>()) + .add(lifecycle); + } + } + + /** + * Removes the given lifecycle of the given query ID. + * This method uses {@link Object#equals} to find the lifecycle matched to the given parameter. + */ + public void remove(String sqlQueryId, SqlLifecycle lifecycle) + { + synchronized (lock) { + List lifecycles = sqlLifecycles.get(sqlQueryId); + if (lifecycles != null) { + lifecycles.remove(lifecycle); + if (lifecycles.isEmpty()) { + sqlLifecycles.remove(sqlQueryId); + } + } + } + } + + /** + * For the given sqlQueryId, this method removes all lifecycles that match to the given list of lifecycles. + * This method uses {@link Object#equals} for matching lifecycles. + */ + public void removeAll(String sqlQueryId, List lifecyclesToRemove) + { + synchronized (lock) { + List lifecycles = sqlLifecycles.get(sqlQueryId); + if (lifecycles != null) { + lifecycles.removeAll(lifecyclesToRemove); + if (lifecycles.isEmpty()) { + sqlLifecycles.remove(sqlQueryId); + } + } + } + } + + /** + * Returns a snapshot of the lifecycles for the given sqlQueryId. + */ + public List getAll(String sqlQueryId) + { + synchronized (lock) { + List lifecycles = sqlLifecycles.get(sqlQueryId); + return lifecycles == null ? Collections.emptyList() : ImmutableList.copyOf(lifecycles); + } + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/SqlRowTransformer.java b/sql/src/main/java/org/apache/druid/sql/SqlRowTransformer.java new file mode 100644 index 00000000000..5570c42fd83 --- /dev/null +++ b/sql/src/main/java/org/apache/druid/sql/SqlRowTransformer.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.sql.calcite.planner.Calcites; +import org.joda.time.DateTimeZone; +import org.joda.time.format.ISODateTimeFormat; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.List; + +/** + * This class transforms the values of TIMESTAMP or DATE type for sql query results. + * The transformation is required only when the sql query is submitted to {@link org.apache.druid.sql.http.SqlResource}. + */ +public class SqlRowTransformer +{ + private final DateTimeZone timeZone; + private final List fieldList; + + // Remember which columns are time-typed, so we can emit ISO8601 instead of millis values. + private final boolean[] timeColumns; + private final boolean[] dateColumns; + + SqlRowTransformer(DateTimeZone timeZone, RelDataType rowType) + { + this.timeZone = timeZone; + this.fieldList = new ArrayList<>(rowType.getFieldCount()); + this.timeColumns = new boolean[rowType.getFieldCount()]; + this.dateColumns = new boolean[rowType.getFieldCount()]; + for (int i = 0; i < rowType.getFieldCount(); i++) { + final SqlTypeName sqlTypeName = rowType.getFieldList().get(i).getType().getSqlTypeName(); + timeColumns[i] = sqlTypeName == SqlTypeName.TIMESTAMP; + dateColumns[i] = sqlTypeName == SqlTypeName.DATE; + fieldList.add(rowType.getFieldList().get(i).getName()); + } + } + + public List getFieldList() + { + return fieldList; + } + + @Nullable + public Object transform(Object[] row, int i) + { + if (row[i] == null) { + return null; + } else if (timeColumns[i]) { + return ISODateTimeFormat.dateTime().print( + Calcites.calciteTimestampToJoda((long) row[i], timeZone) + ); + } else if (dateColumns[i]) { + return ISODateTimeFormat.dateTime().print( + Calcites.calciteDateToJoda((int) row[i], timeZone) + ); + } else { + return row[i]; + } + } +} diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidStatement.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidStatement.java index ff4f2681016..560116c0be0 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidStatement.java @@ -58,6 +58,7 @@ public class DruidStatement implements Closeable private final String connectionId; private final int statementId; private final Map queryContext; + @GuardedBy("lock") private final SqlLifecycle sqlLifecycle; private final Runnable onClose; private final Object lock = new Object(); @@ -261,14 +262,6 @@ public class DruidStatement implements Closeable } } - public RelDataType getRowType() - { - synchronized (lock) { - ensure(State.PREPARED, State.RUNNING, State.DONE); - return sqlLifecycle.rowType(); - } - } - public long getCurrentOffset() { synchronized (lock) { @@ -348,7 +341,9 @@ public class DruidStatement implements Closeable // First close. Run the onClose function. try { onClose.run(); - sqlLifecycle.emitLogsAndMetrics(t, null, -1); + synchronized (lock) { + sqlLifecycle.finalizeStateAndEmitLogsAndMetrics(t, null, -1); + } } catch (Throwable t1) { t.addSuppressed(t1); @@ -362,7 +357,9 @@ public class DruidStatement implements Closeable // First close. Run the onClose function. try { if (!(this.throwable instanceof ForbiddenException)) { - sqlLifecycle.emitLogsAndMetrics(this.throwable, null, -1); + synchronized (lock) { + sqlLifecycle.finalizeStateAndEmitLogsAndMetrics(this.throwable, null, -1); + } } onClose.run(); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java index 0d7f979d42f..59d4bd8f019 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/PlannerContext.java @@ -71,7 +71,7 @@ public class PlannerContext private final Map queryContext; private final String sqlQueryId; private final boolean stringifyArrays; - private final List nativeQueryIds = new CopyOnWriteArrayList<>(); + private final CopyOnWriteArrayList nativeQueryIds = new CopyOnWriteArrayList<>(); // bindings for dynamic parameters to bind during planning private List parameters = Collections.emptyList(); // result of authentication, providing identity to authorize set of resources produced by validation @@ -204,7 +204,7 @@ public class PlannerContext return sqlQueryId; } - public List getNativeQueryIds() + public CopyOnWriteArrayList getNativeQueryIds() { return nativeQueryIds; } diff --git a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java index b880209da4a..232cf5b22a2 100644 --- a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java +++ b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java @@ -22,14 +22,14 @@ package org.apache.druid.sql.http; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; import com.google.common.io.CountingOutputStream; import com.google.inject.Inject; import org.apache.calcite.plan.RelOptPlanner; -import org.apache.calcite.rel.type.RelDataTypeField; -import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.guice.annotations.Json; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.common.logger.Logger; @@ -39,19 +39,24 @@ import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryTimeoutException; import org.apache.druid.query.QueryUnsupportedException; import org.apache.druid.query.ResourceLimitExceededException; +import org.apache.druid.server.security.Access; +import org.apache.druid.server.security.AuthorizationUtils; +import org.apache.druid.server.security.AuthorizerMapper; import org.apache.druid.server.security.ForbiddenException; +import org.apache.druid.server.security.Resource; import org.apache.druid.sql.SqlLifecycle; import org.apache.druid.sql.SqlLifecycleFactory; +import org.apache.druid.sql.SqlLifecycleManager; import org.apache.druid.sql.SqlPlanningException; -import org.apache.druid.sql.calcite.planner.Calcites; -import org.apache.druid.sql.calcite.planner.PlannerContext; -import org.joda.time.DateTimeZone; -import org.joda.time.format.ISODateTimeFormat; +import org.apache.druid.sql.SqlRowTransformer; +import javax.annotation.Nullable; import javax.servlet.http.HttpServletRequest; import javax.ws.rs.Consumes; +import javax.ws.rs.DELETE; import javax.ws.rs.POST; import javax.ws.rs.Path; +import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; @@ -59,8 +64,9 @@ import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status; import javax.ws.rs.core.StreamingOutput; import java.io.IOException; -import java.util.Arrays; import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; @Path("/druid/v2/sql/") public class SqlResource @@ -68,16 +74,22 @@ public class SqlResource private static final Logger log = new Logger(SqlResource.class); private final ObjectMapper jsonMapper; + private final AuthorizerMapper authorizerMapper; private final SqlLifecycleFactory sqlLifecycleFactory; + private final SqlLifecycleManager sqlLifecycleManager; @Inject public SqlResource( @Json ObjectMapper jsonMapper, - SqlLifecycleFactory sqlLifecycleFactory + AuthorizerMapper authorizerMapper, + SqlLifecycleFactory sqlLifecycleFactory, + SqlLifecycleManager sqlLifecycleManager ) { this.jsonMapper = Preconditions.checkNotNull(jsonMapper, "jsonMapper"); + this.authorizerMapper = Preconditions.checkNotNull(authorizerMapper, "authorizerMapper"); this.sqlLifecycleFactory = Preconditions.checkNotNull(sqlLifecycleFactory, "sqlLifecycleFactory"); + this.sqlLifecycleManager = Preconditions.checkNotNull(sqlLifecycleManager, "sqlLifecycleManager"); } @POST @@ -98,24 +110,14 @@ public class SqlResource lifecycle.setParameters(sqlQuery.getParameterList()); lifecycle.validateAndAuthorize(req); - final PlannerContext plannerContext = lifecycle.plan(); - final DateTimeZone timeZone = plannerContext.getTimeZone(); + // must add after lifecycle is authorized + sqlLifecycleManager.add(sqlQueryId, lifecycle); - // Remember which columns are time-typed, so we can emit ISO8601 instead of millis values. - // Also store list of all column names, for X-Druid-Sql-Columns header. - final List fieldList = lifecycle.rowType().getFieldList(); - final boolean[] timeColumns = new boolean[fieldList.size()]; - final boolean[] dateColumns = new boolean[fieldList.size()]; - final String[] columnNames = new String[fieldList.size()]; + lifecycle.plan(); - for (int i = 0; i < fieldList.size(); i++) { - final SqlTypeName sqlTypeName = fieldList.get(i).getType().getSqlTypeName(); - timeColumns[i] = sqlTypeName == SqlTypeName.TIMESTAMP; - dateColumns[i] = sqlTypeName == SqlTypeName.DATE; - columnNames[i] = fieldList.get(i).getName(); - } - - final Yielder yielder0 = Yielders.each(lifecycle.execute()); + final SqlRowTransformer rowTransformer = lifecycle.createRowTransformer(); + final Sequence sequence = lifecycle.execute(); + final Yielder yielder0 = Yielders.each(sequence); try { return Response @@ -130,30 +132,15 @@ public class SqlResource writer.writeResponseStart(); if (sqlQuery.includeHeader()) { - writer.writeHeader(Arrays.asList(columnNames)); + writer.writeHeader(rowTransformer.getFieldList()); } while (!yielder.isDone()) { final Object[] row = yielder.get(); writer.writeRowStart(); - for (int i = 0; i < fieldList.size(); i++) { - final Object value; - - if (row[i] == null) { - value = null; - } else if (timeColumns[i]) { - value = ISODateTimeFormat.dateTime().print( - Calcites.calciteTimestampToJoda((long) row[i], timeZone) - ); - } else if (dateColumns[i]) { - value = ISODateTimeFormat.dateTime().print( - Calcites.calciteDateToJoda((int) row[i], timeZone) - ); - } else { - value = row[i]; - } - - writer.writeRowField(fieldList.get(i).getName(), value); + for (int i = 0; i < rowTransformer.getFieldList().size(); i++) { + final Object value = rowTransformer.transform(row, i); + writer.writeRowField(rowTransformer.getFieldList().get(i), value); } writer.writeRowEnd(); yielder = yielder.next(null); @@ -168,7 +155,7 @@ public class SqlResource } finally { yielder.close(); - lifecycle.emitLogsAndMetrics(e, remoteAddr, os.getCount()); + endLifecycle(sqlQueryId, lifecycle, e, remoteAddr, os.getCount()); } } ) @@ -182,27 +169,28 @@ public class SqlResource } } catch (QueryCapacityExceededException cap) { - lifecycle.emitLogsAndMetrics(cap, remoteAddr, -1); + endLifecycle(sqlQueryId, lifecycle, cap, remoteAddr, -1); return buildNonOkResponse(QueryCapacityExceededException.STATUS_CODE, cap); } catch (QueryUnsupportedException unsupported) { - lifecycle.emitLogsAndMetrics(unsupported, remoteAddr, -1); + endLifecycle(sqlQueryId, lifecycle, unsupported, remoteAddr, -1); return buildNonOkResponse(QueryUnsupportedException.STATUS_CODE, unsupported); } catch (QueryTimeoutException timeout) { - lifecycle.emitLogsAndMetrics(timeout, remoteAddr, -1); + endLifecycle(sqlQueryId, lifecycle, timeout, remoteAddr, -1); return buildNonOkResponse(QueryTimeoutException.STATUS_CODE, timeout); } catch (SqlPlanningException | ResourceLimitExceededException e) { - lifecycle.emitLogsAndMetrics(e, remoteAddr, -1); + endLifecycle(sqlQueryId, lifecycle, e, remoteAddr, -1); return buildNonOkResponse(BadQueryException.STATUS_CODE, e); } catch (ForbiddenException e) { + endLifecycleWithoutEmittingMetrics(sqlQueryId, lifecycle); throw e; // let ForbiddenExceptionMapper handle this } catch (Exception e) { log.warn(e, "Failed to handle query: %s", sqlQuery); - lifecycle.emitLogsAndMetrics(e, remoteAddr, -1); + endLifecycle(sqlQueryId, lifecycle, e, remoteAddr, -1); final Exception exceptionToReport; @@ -222,11 +210,66 @@ public class SqlResource } } - Response buildNonOkResponse(int status, Exception e) throws JsonProcessingException + private void endLifecycleWithoutEmittingMetrics( + String sqlQueryId, + SqlLifecycle lifecycle + ) + { + sqlLifecycleManager.remove(sqlQueryId, lifecycle); + } + + private void endLifecycle( + String sqlQueryId, + SqlLifecycle lifecycle, + @Nullable final Throwable e, + @Nullable final String remoteAddress, + final long bytesWritten + ) + { + lifecycle.finalizeStateAndEmitLogsAndMetrics(e, remoteAddress, bytesWritten); + sqlLifecycleManager.remove(sqlQueryId, lifecycle); + } + + private Response buildNonOkResponse(int status, Exception e) throws JsonProcessingException { return Response.status(status) .type(MediaType.APPLICATION_JSON_TYPE) .entity(jsonMapper.writeValueAsBytes(e)) .build(); } + + @DELETE + @Path("{id}") + @Produces(MediaType.APPLICATION_JSON) + public Response cancelQuery( + @PathParam("id") String sqlQueryId, + @Context final HttpServletRequest req + ) + { + log.debug("Received cancel request for query [%s]", sqlQueryId); + + List lifecycles = sqlLifecycleManager.getAll(sqlQueryId); + if (lifecycles.isEmpty()) { + return Response.status(Status.NOT_FOUND).build(); + } + Set resources = lifecycles + .stream() + .flatMap(lifecycle -> lifecycle.getAuthorizedResources().stream()) + .collect(Collectors.toSet()); + Access access = AuthorizationUtils.authorizeAllResourceActions( + req, + Iterables.transform(resources, AuthorizationUtils.RESOURCE_READ_RA_GENERATOR), + authorizerMapper + ); + + if (access.isAllowed()) { + // should remove only the lifecycles in the snapshot. + sqlLifecycleManager.removeAll(sqlQueryId, lifecycles); + lifecycles.forEach(SqlLifecycle::cancel); + return Response.status(Status.ACCEPTED).build(); + } else { + // Return 404 for authorization failures as well + return Response.status(Status.NOT_FOUND).build(); + } + } } diff --git a/sql/src/test/java/org/apache/druid/sql/SqlLifecycleManagerTest.java b/sql/src/test/java/org/apache/druid/sql/SqlLifecycleManagerTest.java new file mode 100644 index 00000000000..8ddef9fc8d5 --- /dev/null +++ b/sql/src/test/java/org/apache/druid/sql/SqlLifecycleManagerTest.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql; + +import com.google.common.collect.ImmutableList; +import org.apache.druid.sql.SqlLifecycle.State; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; + +import java.util.List; + +public class SqlLifecycleManagerTest +{ + private SqlLifecycleManager lifecycleManager; + + @Before + public void setup() + { + lifecycleManager = new SqlLifecycleManager(); + } + + @Test + public void testAddNonAuthorizedLifeCycle() + { + SqlLifecycle lifecycle = mockLifecycle(State.INITIALIZED); + Assert.assertThrows(AssertionError.class, () -> lifecycleManager.add("sqlId", lifecycle)); + } + + @Test + public void testAddAuthorizedLifecycle() + { + final String sqlId = "sqlId"; + SqlLifecycle lifecycle = mockLifecycle(State.AUTHORIZED); + lifecycleManager.add(sqlId, lifecycle); + Assert.assertEquals(ImmutableList.of(lifecycle), lifecycleManager.getAll(sqlId)); + } + + @Test + public void testRemoveValidLifecycle() + { + final String sqlId = "sqlId"; + SqlLifecycle lifecycle = mockLifecycle(State.AUTHORIZED); + lifecycleManager.add(sqlId, lifecycle); + Assert.assertEquals(ImmutableList.of(lifecycle), lifecycleManager.getAll(sqlId)); + lifecycleManager.remove(sqlId, lifecycle); + Assert.assertEquals(ImmutableList.of(), lifecycleManager.getAll(sqlId)); + } + + @Test + public void testRemoveInvalidSqlQueryId() + { + final String sqlId = "sqlId"; + SqlLifecycle lifecycle = mockLifecycle(State.AUTHORIZED); + lifecycleManager.add(sqlId, lifecycle); + Assert.assertEquals(ImmutableList.of(lifecycle), lifecycleManager.getAll(sqlId)); + lifecycleManager.remove("invalid", lifecycle); + Assert.assertEquals(ImmutableList.of(lifecycle), lifecycleManager.getAll(sqlId)); + } + + @Test + public void testRemoveValidSqlQueryIdDifferntLifecycleObject() + { + final String sqlId = "sqlId"; + SqlLifecycle lifecycle = mockLifecycle(State.AUTHORIZED); + lifecycleManager.add(sqlId, lifecycle); + Assert.assertEquals(ImmutableList.of(lifecycle), lifecycleManager.getAll(sqlId)); + lifecycleManager.remove(sqlId, mockLifecycle(State.AUTHORIZED)); + Assert.assertEquals(ImmutableList.of(lifecycle), lifecycleManager.getAll(sqlId)); + } + + @Test + public void testRemoveAllValidSqlQueryIdSubsetOfLifecycles() + { + final String sqlId = "sqlId"; + final List lifecycles = ImmutableList.of( + mockLifecycle(State.AUTHORIZED), + mockLifecycle(State.AUTHORIZED), + mockLifecycle(State.AUTHORIZED) + ); + lifecycles.forEach(lifecycle -> lifecycleManager.add(sqlId, lifecycle)); + Assert.assertEquals(lifecycles, lifecycleManager.getAll(sqlId)); + lifecycleManager.removeAll(sqlId, ImmutableList.of(lifecycles.get(0), lifecycles.get(1))); + Assert.assertEquals(ImmutableList.of(lifecycles.get(2)), lifecycleManager.getAll(sqlId)); + } + + @Test + public void testRemoveAllInvalidSqlQueryId() + { + final String sqlId = "sqlId"; + final List lifecycles = ImmutableList.of( + mockLifecycle(State.AUTHORIZED), + mockLifecycle(State.AUTHORIZED), + mockLifecycle(State.AUTHORIZED) + ); + lifecycles.forEach(lifecycle -> lifecycleManager.add(sqlId, lifecycle)); + Assert.assertEquals(lifecycles, lifecycleManager.getAll(sqlId)); + lifecycleManager.removeAll("invalid", ImmutableList.of(lifecycles.get(0), lifecycles.get(1))); + Assert.assertEquals(lifecycles, lifecycleManager.getAll(sqlId)); + } + + @Test + public void testGetAllReturnsListCopy() + { + final String sqlId = "sqlId"; + final List lifecycles = ImmutableList.of( + mockLifecycle(State.AUTHORIZED), + mockLifecycle(State.AUTHORIZED), + mockLifecycle(State.AUTHORIZED) + ); + lifecycles.forEach(lifecycle -> lifecycleManager.add(sqlId, lifecycle)); + final List lifecyclesFromGetAll = lifecycleManager.getAll(sqlId); + lifecycleManager.removeAll(sqlId, lifecyclesFromGetAll); + Assert.assertEquals(lifecycles, lifecyclesFromGetAll); + Assert.assertTrue(lifecycleManager.getAll(sqlId).isEmpty()); + } + + private static SqlLifecycle mockLifecycle(State state) + { + SqlLifecycle lifecycle = Mockito.mock(SqlLifecycle.class); + Mockito.when(lifecycle.getState()).thenReturn(state); + return lifecycle; + } +} diff --git a/sql/src/test/java/org/apache/druid/sql/SqlLifecycleTest.java b/sql/src/test/java/org/apache/druid/sql/SqlLifecycleTest.java index 6fd383bb247..60659120013 100644 --- a/sql/src/test/java/org/apache/druid/sql/SqlLifecycleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/SqlLifecycleTest.java @@ -30,6 +30,7 @@ import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.java.util.emitter.service.ServiceEventBuilder; import org.apache.druid.query.QueryContexts; +import org.apache.druid.server.QueryStackTests; import org.apache.druid.server.log.RequestLogger; import org.apache.druid.server.security.Access; import org.apache.druid.server.security.AuthConfig; @@ -51,6 +52,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; public class SqlLifecycleTest { @@ -65,7 +67,12 @@ public class SqlLifecycleTest this.plannerFactory = EasyMock.createMock(PlannerFactory.class); this.serviceEmitter = EasyMock.createMock(ServiceEmitter.class); this.requestLogger = EasyMock.createMock(RequestLogger.class); - this.sqlLifecycleFactory = new SqlLifecycleFactory(plannerFactory, serviceEmitter, requestLogger); + this.sqlLifecycleFactory = new SqlLifecycleFactory( + plannerFactory, + serviceEmitter, + requestLogger, + QueryStackTests.DEFAULT_NOOP_SCHEDULER + ); } @Test @@ -142,8 +149,8 @@ public class SqlLifecycleTest mockPlanner.close(); EasyMock.expectLastCall(); EasyMock.replay(plannerFactory, serviceEmitter, requestLogger, mockPlanner, mockPlannerContext, mockPrepareResult, mockPlanResult); - PlannerContext context = lifecycle.plan(); - Assert.assertEquals(mockPlannerContext, context); + lifecycle.plan(); + Assert.assertEquals(mockPlannerContext, lifecycle.getPlannerContext()); Assert.assertEquals(SqlLifecycle.State.PLANNED, lifecycle.getState()); EasyMock.verify(plannerFactory, serviceEmitter, requestLogger, mockPlanner, mockPlannerContext, mockPrepareResult, mockPlanResult); EasyMock.reset(plannerFactory, serviceEmitter, requestLogger, mockPlanner, mockPlannerContext, mockPrepareResult, mockPlanResult); @@ -158,7 +165,8 @@ public class SqlLifecycleTest // test emit EasyMock.expect(mockPlannerContext.getSqlQueryId()).andReturn("id").once(); - EasyMock.expect(mockPlannerContext.getNativeQueryIds()).andReturn(ImmutableList.of("id")).times(2); + CopyOnWriteArrayList nativeQueryIds = new CopyOnWriteArrayList<>(ImmutableList.of("id")); + EasyMock.expect(mockPlannerContext.getNativeQueryIds()).andReturn(nativeQueryIds).times(2); EasyMock.expect(mockPlannerContext.getAuthenticationResult()).andReturn(CalciteTests.REGULAR_USER_AUTH_RESULT).once(); serviceEmitter.emit(EasyMock.anyObject(ServiceEventBuilder.class)); @@ -169,7 +177,7 @@ public class SqlLifecycleTest EasyMock.expectLastCall(); EasyMock.replay(plannerFactory, serviceEmitter, requestLogger, mockPlanner, mockPlannerContext, mockPrepareResult, mockPlanResult); - lifecycle.emitLogsAndMetrics(null, null, 10); + lifecycle.finalizeStateAndEmitLogsAndMetrics(null, null, 10); Assert.assertEquals(SqlLifecycle.State.DONE, lifecycle.getState()); EasyMock.verify(plannerFactory, serviceEmitter, requestLogger, mockPlanner, mockPlannerContext, mockPrepareResult, mockPlanResult); EasyMock.reset(plannerFactory, serviceEmitter, requestLogger, mockPlanner, mockPlannerContext, mockPrepareResult, mockPlanResult); @@ -244,8 +252,8 @@ public class SqlLifecycleTest mockPlanner.close(); EasyMock.expectLastCall(); EasyMock.replay(plannerFactory, serviceEmitter, requestLogger, mockPlanner, mockPlannerContext, mockPrepareResult, mockPlanResult); - PlannerContext context = lifecycle.plan(); - Assert.assertEquals(mockPlannerContext, context); + lifecycle.plan(); + Assert.assertEquals(mockPlannerContext, lifecycle.getPlannerContext()); Assert.assertEquals(SqlLifecycle.State.PLANNED, lifecycle.getState()); EasyMock.verify(plannerFactory, serviceEmitter, requestLogger, mockPlanner, mockPlannerContext, mockPrepareResult, mockPlanResult); EasyMock.reset(plannerFactory, serviceEmitter, requestLogger, mockPlanner, mockPlannerContext, mockPrepareResult, mockPlanResult); @@ -260,7 +268,8 @@ public class SqlLifecycleTest // test emit EasyMock.expect(mockPlannerContext.getSqlQueryId()).andReturn("id").once(); - EasyMock.expect(mockPlannerContext.getNativeQueryIds()).andReturn(ImmutableList.of("id")).times(2); + CopyOnWriteArrayList nativeQueryIds = new CopyOnWriteArrayList<>(ImmutableList.of("id")); + EasyMock.expect(mockPlannerContext.getNativeQueryIds()).andReturn(nativeQueryIds).times(2); EasyMock.expect(mockPlannerContext.getAuthenticationResult()).andReturn(CalciteTests.REGULAR_USER_AUTH_RESULT).once(); serviceEmitter.emit(EasyMock.anyObject(ServiceEventBuilder.class)); @@ -271,7 +280,7 @@ public class SqlLifecycleTest EasyMock.expectLastCall(); EasyMock.replay(plannerFactory, serviceEmitter, requestLogger, mockPlanner, mockPlannerContext, mockPrepareResult, mockPlanResult); - lifecycle.emitLogsAndMetrics(null, null, 10); + lifecycle.finalizeStateAndEmitLogsAndMetrics(null, null, 10); Assert.assertEquals(SqlLifecycle.State.DONE, lifecycle.getState()); EasyMock.verify(plannerFactory, serviceEmitter, requestLogger, mockPlanner, mockPlannerContext, mockPrepareResult, mockPlanResult); EasyMock.reset(plannerFactory, serviceEmitter, requestLogger, mockPlanner, mockPlannerContext, mockPrepareResult, mockPlanResult); diff --git a/sql/src/test/java/org/apache/druid/sql/SqlRowTransformerTest.java b/sql/src/test/java/org/apache/druid/sql/SqlRowTransformerTest.java new file mode 100644 index 00000000000..1ca9db2a3c7 --- /dev/null +++ b/sql/src/test/java/org/apache/druid/sql/SqlRowTransformerTest.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.sql; + +import com.google.common.collect.ImmutableList; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.sql.type.SqlTypeFactoryImpl; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.sql.calcite.planner.Calcites; +import org.apache.druid.sql.calcite.planner.DruidTypeSystem; +import org.apache.druid.sql.calcite.util.CalciteTestBase; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; +import org.joda.time.format.ISODateTimeFormat; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class SqlRowTransformerTest extends CalciteTestBase +{ + private RelDataType rowType; + + @Before + public void setup() + { + final RelDataTypeFactory typeFactory = new SqlTypeFactoryImpl(DruidTypeSystem.INSTANCE); + rowType = typeFactory.createStructType( + ImmutableList.of( + typeFactory.createSqlType(SqlTypeName.TIMESTAMP), + typeFactory.createSqlType(SqlTypeName.DATE), + typeFactory.createSqlType(SqlTypeName.VARCHAR), + typeFactory.createSqlType(SqlTypeName.VARCHAR) + ), + ImmutableList.of( + "timestamp_col", + "date_col", + "string_col", + "null" + ) + ); + } + + @Test + public void testTransformUTC() + { + SqlRowTransformer transformer = new SqlRowTransformer( + DateTimeZone.UTC, + rowType + ); + DateTime timestamp = DateTimes.of("2021-08-01T12:00:00"); + DateTime date = DateTimes.of("2021-01-01"); + Object[] expectedRow = new Object[]{ + ISODateTimeFormat.dateTime().print(timestamp), + ISODateTimeFormat.dateTime().print(date), + "string", + null + }; + Object[] row = new Object[]{ + Calcites.jodaToCalciteTimestamp(timestamp, DateTimeZone.UTC), + Calcites.jodaToCalciteDate(date, DateTimeZone.UTC), + expectedRow[2], + null + }; + Assert.assertArrayEquals( + expectedRow, + IntStream.range(0, expectedRow.length).mapToObj(i -> transformer.transform(row, i)).toArray() + ); + } + + @Test + public void testTransformNonUTC() + { + DateTimeZone timeZone = DateTimes.inferTzFromString("Asia/Seoul"); + SqlRowTransformer transformer = new SqlRowTransformer( + timeZone, + rowType + ); + DateTime timestamp = new DateTime("2021-08-01T12:00:00", timeZone); + DateTime date = new DateTime("2021-01-01", timeZone); + Object[] expectedRow = new Object[]{ + ISODateTimeFormat.dateTime().withZone(timeZone).print(timestamp), + ISODateTimeFormat.dateTime().withZone(timeZone).print(date), + "string", + null + }; + Object[] row = new Object[]{ + Calcites.jodaToCalciteTimestamp(timestamp, timeZone), + Calcites.jodaToCalciteDate(date, timeZone), + expectedRow[2], + null + }; + Assert.assertArrayEquals( + expectedRow, + IntStream.range(0, expectedRow.length).mapToObj(i -> transformer.transform(row, i)).toArray() + ); + } + + @Test + public void testGetFieldList() + { + SqlRowTransformer transformer = new SqlRowTransformer( + DateTimeZone.UTC, + rowType + ); + + Assert.assertEquals( + rowType.getFieldList().stream().map(RelDataTypeField::getName).collect(Collectors.toList()), + transformer.getFieldList() + ); + } +} diff --git a/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java b/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java index ef72cbd7e68..3278bb6a80f 100644 --- a/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java +++ b/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java @@ -39,6 +39,7 @@ import org.apache.calcite.avatica.server.AbstractAvaticaHandler; import org.apache.calcite.schema.SchemaPlus; import org.apache.druid.common.config.NullHandling; import org.apache.druid.guice.GuiceInjectors; +import org.apache.druid.guice.LazySingleton; import org.apache.druid.initialization.Initialization; import org.apache.druid.java.util.common.DateTimes; import org.apache.druid.java.util.common.Pair; @@ -48,6 +49,8 @@ import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.server.QueryLifecycleFactory; +import org.apache.druid.server.QueryScheduler; +import org.apache.druid.server.QuerySchedulerProvider; import org.apache.druid.server.QueryStackTests; import org.apache.druid.server.RequestLogLine; import org.apache.druid.server.log.RequestLogger; @@ -195,6 +198,10 @@ public abstract class DruidAvaticaHandlerTest extends CalciteTestBase .toInstance(CalciteTests.DRUID_SCHEMA_NAME); binder.bind(AvaticaServerConfig.class).toInstance(AVATICA_CONFIG); binder.bind(ServiceEmitter.class).to(NoopServiceEmitter.class); + binder.bind(QuerySchedulerProvider.class).in(LazySingleton.class); + binder.bind(QueryScheduler.class) + .toProvider(QuerySchedulerProvider.class) + .in(LazySingleton.class); } } ) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTests.java b/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTests.java index 569e831240c..4e394825ee4 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTests.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/util/CalciteTests.java @@ -784,7 +784,8 @@ public class CalciteTests return new SqlLifecycleFactory( plannerFactory, new ServiceEmitter("dummy", "dummy", new NoopEmitter()), - new NoopRequestLogger() + new NoopRequestLogger(), + QueryStackTests.DEFAULT_NOOP_SCHEDULER ); } diff --git a/sql/src/test/java/org/apache/druid/sql/guice/SqlModuleTest.java b/sql/src/test/java/org/apache/druid/sql/guice/SqlModuleTest.java index b5c4aa740ae..e80ae48cb86 100644 --- a/sql/src/test/java/org/apache/druid/sql/guice/SqlModuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/guice/SqlModuleTest.java @@ -51,6 +51,8 @@ import org.apache.druid.query.QueryToolChestWarehouse; import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider; import org.apache.druid.segment.join.JoinableFactory; import org.apache.druid.segment.loading.SegmentLoader; +import org.apache.druid.server.QueryScheduler; +import org.apache.druid.server.QuerySchedulerProvider; import org.apache.druid.server.log.NoopRequestLogger; import org.apache.druid.server.log.RequestLogger; import org.apache.druid.server.security.AuthorizerMapper; @@ -192,7 +194,10 @@ public class SqlModuleTest binder.bind(LookupExtractorFactoryContainerProvider.class).toInstance(lookupExtractorFactoryContainerProvider); binder.bind(JoinableFactory.class).toInstance(joinableFactory); binder.bind(SegmentLoader.class).toInstance(segmentLoader); - + binder.bind(QuerySchedulerProvider.class).in(LazySingleton.class); + binder.bind(QueryScheduler.class) + .toProvider(QuerySchedulerProvider.class) + .in(LazySingleton.class); }, new SqlModule(props), new TestViewManagerModule() diff --git a/sql/src/test/java/org/apache/druid/sql/http/SqlHttpModuleTest.java b/sql/src/test/java/org/apache/druid/sql/http/SqlHttpModuleTest.java index e0abc585842..c035f7c382c 100644 --- a/sql/src/test/java/org/apache/druid/sql/http/SqlHttpModuleTest.java +++ b/sql/src/test/java/org/apache/druid/sql/http/SqlHttpModuleTest.java @@ -24,8 +24,11 @@ import com.google.inject.Guice; import com.google.inject.Injector; import com.google.inject.Key; import com.google.inject.TypeLiteral; +import org.apache.druid.guice.DruidGuiceExtensions; +import org.apache.druid.guice.LifecycleModule; import org.apache.druid.guice.annotations.JSR311Resource; import org.apache.druid.guice.annotations.Json; +import org.apache.druid.server.security.AuthorizerMapper; import org.apache.druid.sql.SqlLifecycleFactory; import org.easymock.EasyMockRunner; import org.easymock.Mock; @@ -34,6 +37,7 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import java.util.Collections; import java.util.Set; @RunWith(EasyMockRunner.class) @@ -52,11 +56,15 @@ public class SqlHttpModuleTest { target = new SqlHttpModule(); injector = Guice.createInjector( + new LifecycleModule(), + new DruidGuiceExtensions(), binder -> { binder.bind(ObjectMapper.class).annotatedWith(Json.class).toInstance(jsonMpper); binder.bind(SqlLifecycleFactory.class).toInstance(sqlLifecycleFactory); + binder.bind(AuthorizerMapper.class).toInstance(new AuthorizerMapper(Collections.emptyMap())); }, - target); + target + ); } @Test diff --git a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java index dd0cee1ca8b..02f596adbda 100644 --- a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java +++ b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java @@ -29,15 +29,19 @@ import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import org.apache.calcite.avatica.SqlType; import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.tools.RelConversionException; import org.apache.druid.common.config.NullHandling; +import org.apache.druid.common.guava.SettableSupplier; import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.NonnullPair; import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.guava.LazySequence; import org.apache.druid.java.util.common.guava.Sequence; import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.Query; import org.apache.druid.query.QueryCapacityExceededException; @@ -51,13 +55,16 @@ import org.apache.druid.query.ResourceLimitExceededException; import org.apache.druid.server.QueryScheduler; import org.apache.druid.server.QueryStackTests; import org.apache.druid.server.initialization.ServerConfig; +import org.apache.druid.server.log.RequestLogger; import org.apache.druid.server.log.TestRequestLogger; import org.apache.druid.server.metrics.NoopServiceEmitter; import org.apache.druid.server.scheduling.HiLoQueryLaningStrategy; import org.apache.druid.server.scheduling.ManualQueryPrioritizationStrategy; import org.apache.druid.server.security.AuthConfig; import org.apache.druid.server.security.ForbiddenException; +import org.apache.druid.sql.SqlLifecycle; import org.apache.druid.sql.SqlLifecycleFactory; +import org.apache.druid.sql.SqlLifecycleManager; import org.apache.druid.sql.SqlPlanningException.PlanningError; import org.apache.druid.sql.calcite.planner.DruidOperatorTable; import org.apache.druid.sql.calcite.planner.PlannerConfig; @@ -79,6 +86,7 @@ import org.junit.rules.TemporaryFolder; import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.Response; +import javax.ws.rs.core.Response.Status; import javax.ws.rs.core.StreamingOutput; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -87,7 +95,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.stream.Collectors; @@ -107,6 +117,12 @@ public class SqlResourceTest extends CalciteTestBase private SqlResource resource; private HttpServletRequest req; private ListeningExecutorService executorService; + private SqlLifecycleManager lifecycleManager; + + private CountDownLatch lifecycleAddLatch; + private final SettableSupplier> validateAndAuthorizeLatchSupplier = new SettableSupplier<>(); + private final SettableSupplier> planLatchSupplier = new SettableSupplier<>(); + private final SettableSupplier> executeLatchSupplier = new SettableSupplier<>(); private boolean sleep = false; @@ -204,13 +220,45 @@ public class SqlResourceTest extends CalciteTestBase CalciteTests.DRUID_SCHEMA_NAME ); + lifecycleManager = new SqlLifecycleManager() + { + @Override + public void add(String sqlQueryId, SqlLifecycle lifecycle) + { + super.add(sqlQueryId, lifecycle); + if (lifecycleAddLatch != null) { + lifecycleAddLatch.countDown(); + } + } + }; + final ServiceEmitter emitter = new NoopServiceEmitter(); resource = new SqlResource( JSON_MAPPER, + CalciteTests.TEST_AUTHORIZER_MAPPER, new SqlLifecycleFactory( plannerFactory, - new NoopServiceEmitter(), - testRequestLogger + emitter, + testRequestLogger, + scheduler ) + { + @Override + public SqlLifecycle factorize() + { + return new TestSqlLifecycle( + plannerFactory, + emitter, + testRequestLogger, + scheduler, + System.currentTimeMillis(), + System.nanoTime(), + validateAndAuthorizeLatchSupplier, + planLatchSupplier, + executeLatchSupplier + ); + } + }, + lifecycleManager ); } @@ -220,6 +268,7 @@ public class SqlResourceTest extends CalciteTestBase walker.close(); walker = null; executorService.shutdownNow(); + executorService.awaitTermination(2, TimeUnit.SECONDS); } @Test @@ -243,7 +292,7 @@ public class SqlResourceTest extends CalciteTestBase try { resource.doPost( - new SqlQuery("select count(*) from forbiddenDatasource", null, false, null, null), + createSimpleQueryWithId("id", "select count(*) from forbiddenDatasource"), testRequest ); Assert.fail("doPost did not throw ForbiddenException for an unauthorized query"); @@ -252,13 +301,14 @@ public class SqlResourceTest extends CalciteTestBase // expected } Assert.assertEquals(0, testRequestLogger.getSqlQueryLogs().size()); + Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @Test public void testCountStar() throws Exception { final List> rows = doPost( - new SqlQuery("SELECT COUNT(*) AS cnt, 'foo' AS TheFoo FROM druid.foo", null, false, null, null) + createSimpleQueryWithId("id", "SELECT COUNT(*) AS cnt, 'foo' AS TheFoo FROM druid.foo") ).rhs; Assert.assertEquals( @@ -268,6 +318,7 @@ public class SqlResourceTest extends CalciteTestBase rows ); checkSqlRequestLog(true); + Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @@ -275,7 +326,10 @@ public class SqlResourceTest extends CalciteTestBase public void testCountStarExtendedCharacters() throws Exception { final List> rows = doPost( - new SqlQuery("SELECT COUNT(*) AS cnt FROM druid.lotsocolumns WHERE dimMultivalEnumerated = 'ㅑ ㅓ ㅕ ㅗ ㅛ ㅜ ㅠ ㅡ ㅣ'", null, false, null, null) + createSimpleQueryWithId( + "id", + "SELECT COUNT(*) AS cnt FROM druid.lotsocolumns WHERE dimMultivalEnumerated = 'ㅑ ㅓ ㅕ ㅗ ㅛ ㅜ ㅠ ㅡ ㅣ'" + ) ).rhs; Assert.assertEquals( @@ -285,6 +339,7 @@ public class SqlResourceTest extends CalciteTestBase rows ); checkSqlRequestLog(true); + Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @Test @@ -490,7 +545,11 @@ public class SqlResourceTest extends CalciteTestBase public void testArrayLinesResultFormat() throws Exception { final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; - final String response = doPostRaw(new SqlQuery(query, ResultFormat.ARRAYLINES, false, null, null)).rhs; + final Pair pair = doPostRaw( + new SqlQuery(query, ResultFormat.ARRAYLINES, false, null, null) + ); + Assert.assertNull(pair.lhs); + final String response = pair.rhs; final String nullStr = NullHandling.replaceWithDefault() ? "" : null; final List lines = Splitter.on('\n').splitToList(response); @@ -531,7 +590,11 @@ public class SqlResourceTest extends CalciteTestBase public void testArrayLinesResultFormatWithHeader() throws Exception { final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; - final String response = doPostRaw(new SqlQuery(query, ResultFormat.ARRAYLINES, true, null, null)).rhs; + final Pair pair = doPostRaw( + new SqlQuery(query, ResultFormat.ARRAYLINES, true, null, null) + ); + Assert.assertNull(pair.lhs); + final String response = pair.rhs; final String nullStr = NullHandling.replaceWithDefault() ? "" : null; final List lines = Splitter.on('\n').splitToList(response); @@ -622,7 +685,11 @@ public class SqlResourceTest extends CalciteTestBase public void testObjectLinesResultFormat() throws Exception { final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; - final String response = doPostRaw(new SqlQuery(query, ResultFormat.OBJECTLINES, false, null, null)).rhs; + final Pair pair = doPostRaw( + new SqlQuery(query, ResultFormat.OBJECTLINES, false, null, null) + ); + Assert.assertNull(pair.lhs); + final String response = pair.rhs; final String nullStr = NullHandling.replaceWithDefault() ? "" : null; final Function, Map> transformer = m -> { return Maps.transformEntries( @@ -675,7 +742,11 @@ public class SqlResourceTest extends CalciteTestBase public void testCsvResultFormat() throws Exception { final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; - final String response = doPostRaw(new SqlQuery(query, ResultFormat.CSV, false, null, null)).rhs; + final Pair pair = doPostRaw( + new SqlQuery(query, ResultFormat.CSV, false, null, null) + ); + Assert.assertNull(pair.lhs); + final String response = pair.rhs; final List lines = Splitter.on('\n').splitToList(response); Assert.assertEquals( @@ -693,7 +764,11 @@ public class SqlResourceTest extends CalciteTestBase public void testCsvResultFormatWithHeaders() throws Exception { final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; - final String response = doPostRaw(new SqlQuery(query, ResultFormat.CSV, true, null, null)).rhs; + final Pair pair = doPostRaw( + new SqlQuery(query, ResultFormat.CSV, true, null, null) + ); + Assert.assertNull(pair.lhs); + final String response = pair.rhs; final List lines = Splitter.on('\n').splitToList(response); Assert.assertEquals( @@ -736,13 +811,7 @@ public class SqlResourceTest extends CalciteTestBase public void testCannotParse() throws Exception { final QueryException exception = doPost( - new SqlQuery( - "FROM druid.foo", - ResultFormat.OBJECT, - false, - null, - null - ) + createSimpleQueryWithId("id", "FROM druid.foo") ).lhs; Assert.assertNotNull(exception); @@ -750,19 +819,14 @@ public class SqlResourceTest extends CalciteTestBase Assert.assertEquals(PlanningError.SQL_PARSE_ERROR.getErrorClass(), exception.getErrorClass()); Assert.assertTrue(exception.getMessage().contains("Encountered \"FROM\" at line 1, column 1.")); checkSqlRequestLog(false); + Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @Test public void testCannotValidate() throws Exception { final QueryException exception = doPost( - new SqlQuery( - "SELECT dim4 FROM druid.foo", - ResultFormat.OBJECT, - false, - null, - null - ) + createSimpleQueryWithId("id", "SELECT dim4 FROM druid.foo") ).lhs; Assert.assertNotNull(exception); @@ -770,6 +834,7 @@ public class SqlResourceTest extends CalciteTestBase Assert.assertEquals(PlanningError.VALIDATION_ERROR.getErrorClass(), exception.getErrorClass()); Assert.assertTrue(exception.getMessage().contains("Column 'dim4' not found in any table")); checkSqlRequestLog(false); + Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @Test @@ -777,7 +842,7 @@ public class SqlResourceTest extends CalciteTestBase { // SELECT + ORDER unsupported final QueryException exception = doPost( - new SqlQuery("SELECT dim1 FROM druid.foo ORDER BY dim1", ResultFormat.OBJECT, false, null, null) + createSimpleQueryWithId("id", "SELECT dim1 FROM druid.foo ORDER BY dim1") ).lhs; Assert.assertNotNull(exception); @@ -788,6 +853,7 @@ public class SqlResourceTest extends CalciteTestBase .contains("Cannot build plan for query: SELECT dim1 FROM druid.foo ORDER BY dim1") ); checkSqlRequestLog(false); + Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @Test @@ -798,7 +864,7 @@ public class SqlResourceTest extends CalciteTestBase "SELECT DISTINCT dim1 FROM foo", ResultFormat.OBJECT, false, - ImmutableMap.of("maxMergingDictionarySize", 1), + ImmutableMap.of("maxMergingDictionarySize", 1, "sqlQueryId", "id"), null ) ).lhs; @@ -807,6 +873,7 @@ public class SqlResourceTest extends CalciteTestBase Assert.assertEquals(exception.getErrorCode(), ResourceLimitExceededException.ERROR_CODE); Assert.assertEquals(exception.getErrorClass(), ResourceLimitExceededException.class.getName()); checkSqlRequestLog(false); + Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @Test @@ -815,7 +882,7 @@ public class SqlResourceTest extends CalciteTestBase String errorMessage = "This will be support in Druid 9999"; SqlQuery badQuery = EasyMock.createMock(SqlQuery.class); EasyMock.expect(badQuery.getQuery()).andReturn("SELECT ANSWER TO LIFE"); - EasyMock.expect(badQuery.getContext()).andReturn(ImmutableMap.of()); + EasyMock.expect(badQuery.getContext()).andReturn(ImmutableMap.of("sqlQueryId", "id")); EasyMock.expect(badQuery.getParameterList()).andThrow(new QueryUnsupportedException(errorMessage)); EasyMock.replay(badQuery); final QueryException exception = doPost(badQuery).lhs; @@ -823,6 +890,7 @@ public class SqlResourceTest extends CalciteTestBase Assert.assertNotNull(exception); Assert.assertEquals(QueryUnsupportedException.ERROR_CODE, exception.getErrorCode()); Assert.assertEquals(QueryUnsupportedException.class.getName(), exception.getErrorClass()); + Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @Test @@ -830,6 +898,7 @@ public class SqlResourceTest extends CalciteTestBase { sleep = true; final int numQueries = 3; + final String sqlQueryId = "tooManyRequestsTest"; List>>>> futures = new ArrayList<>(numQueries); for (int i = 0; i < numQueries; i++) { @@ -840,7 +909,7 @@ public class SqlResourceTest extends CalciteTestBase "SELECT COUNT(*) AS cnt, 'foo' AS TheFoo FROM druid.foo", null, false, - ImmutableMap.of("priority", -5), + ImmutableMap.of("priority", -5, "sqlQueryId", sqlQueryId), null ), makeExpectedReq() @@ -874,12 +943,14 @@ public class SqlResourceTest extends CalciteTestBase Assert.assertEquals(2, success); Assert.assertEquals(1, limited); Assert.assertEquals(3, testRequestLogger.getSqlQueryLogs().size()); + Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty()); } @Test public void testQueryTimeoutException() throws Exception { - Map queryContext = ImmutableMap.of(QueryContexts.TIMEOUT_KEY, 1); + final String sqlQueryId = "timeoutTest"; + Map queryContext = ImmutableMap.of(QueryContexts.TIMEOUT_KEY, 1, "sqlQueryId", sqlQueryId); final QueryException timeoutException = doPost( new SqlQuery( "SELECT CAST(__time AS DATE), dim1, dim2, dim3 FROM druid.foo GROUP by __time, dim1, dim2, dim3 ORDER BY dim2 DESC", @@ -892,9 +963,97 @@ public class SqlResourceTest extends CalciteTestBase Assert.assertNotNull(timeoutException); Assert.assertEquals(timeoutException.getErrorCode(), QueryTimeoutException.ERROR_CODE); Assert.assertEquals(timeoutException.getErrorClass(), QueryTimeoutException.class.getName()); + Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty()); } + @Test + public void testCancelBetweenValidateAndPlan() throws Exception + { + final String sqlQueryId = "toCancel"; + lifecycleAddLatch = new CountDownLatch(1); + CountDownLatch validateAndAuthorizeLatch = new CountDownLatch(1); + validateAndAuthorizeLatchSupplier.set(new NonnullPair<>(validateAndAuthorizeLatch, true)); + CountDownLatch planLatch = new CountDownLatch(1); + planLatchSupplier.set(new NonnullPair<>(planLatch, false)); + Future future = executorService.submit( + () -> resource.doPost( + createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"), + makeExpectedReq() + ) + ); + Assert.assertTrue(validateAndAuthorizeLatch.await(1, TimeUnit.SECONDS)); + Assert.assertTrue(lifecycleAddLatch.await(1, TimeUnit.SECONDS)); + Response response = resource.cancelQuery(sqlQueryId, mockRequestForCancel()); + planLatch.countDown(); + Assert.assertEquals(Status.ACCEPTED.getStatusCode(), response.getStatus()); + + Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty()); + + response = future.get(); + Assert.assertEquals(Status.INTERNAL_SERVER_ERROR.getStatusCode(), response.getStatus()); + QueryException exception = JSON_MAPPER.readValue((byte[]) response.getEntity(), QueryException.class); + Assert.assertEquals( + QueryInterruptedException.QUERY_CANCELLED, + exception.getErrorCode() + ); + } + + @Test + public void testCancelBetweenPlanAndExecute() throws Exception + { + final String sqlQueryId = "toCancel"; + CountDownLatch planLatch = new CountDownLatch(1); + planLatchSupplier.set(new NonnullPair<>(planLatch, true)); + CountDownLatch execLatch = new CountDownLatch(1); + executeLatchSupplier.set(new NonnullPair<>(execLatch, false)); + Future future = executorService.submit( + () -> resource.doPost( + createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"), + makeExpectedReq() + ) + ); + Assert.assertTrue(planLatch.await(1, TimeUnit.SECONDS)); + Response response = resource.cancelQuery(sqlQueryId, mockRequestForCancel()); + execLatch.countDown(); + Assert.assertEquals(Status.ACCEPTED.getStatusCode(), response.getStatus()); + + Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty()); + + response = future.get(); + Assert.assertEquals(Status.INTERNAL_SERVER_ERROR.getStatusCode(), response.getStatus()); + QueryException exception = JSON_MAPPER.readValue((byte[]) response.getEntity(), QueryException.class); + Assert.assertEquals( + QueryInterruptedException.QUERY_CANCELLED, + exception.getErrorCode() + ); + } + + @Test + public void testCancelInvalidQuery() throws Exception + { + final String sqlQueryId = "validQuery"; + CountDownLatch planLatch = new CountDownLatch(1); + planLatchSupplier.set(new NonnullPair<>(planLatch, true)); + CountDownLatch execLatch = new CountDownLatch(1); + executeLatchSupplier.set(new NonnullPair<>(execLatch, false)); + Future future = executorService.submit( + () -> resource.doPost( + createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"), + makeExpectedReq() + ) + ); + Assert.assertTrue(planLatch.await(1, TimeUnit.SECONDS)); + Response response = resource.cancelQuery("invalidQuery", mockRequestForCancel()); + Assert.assertEquals(Status.NOT_FOUND.getStatusCode(), response.getStatus()); + + Assert.assertFalse(lifecycleManager.getAll(sqlQueryId).isEmpty()); + + execLatch.countDown(); + response = future.get(); + Assert.assertEquals(Status.OK.getStatusCode(), response.getStatus()); + } + @SuppressWarnings("unchecked") private void checkSqlRequestLog(boolean success) { @@ -913,6 +1072,10 @@ public class SqlResourceTest extends CalciteTestBase } } + private static SqlQuery createSimpleQueryWithId(String sqlQueryId, String sql) + { + return new SqlQuery(sql, null, false, ImmutableMap.of("sqlQueryId", sqlQueryId), null); + } private Pair>> doPost(final SqlQuery query) throws Exception { @@ -1000,4 +1163,115 @@ public class SqlResourceTest extends CalciteTestBase EasyMock.replay(req); return req; } + + private HttpServletRequest mockRequestForCancel() + { + HttpServletRequest req = EasyMock.createNiceMock(HttpServletRequest.class); + EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .andReturn(CalciteTests.REGULAR_USER_AUTH_RESULT) + .anyTimes(); + EasyMock.expect(req.getAttribute(AuthConfig.DRUID_ALLOW_UNSECURED_PATH)).andReturn(null).anyTimes(); + EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)) + .andReturn(null) + .anyTimes(); + req.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); + EasyMock.expectLastCall().anyTimes(); + EasyMock.replay(req); + return req; + } + + private static class TestSqlLifecycle extends SqlLifecycle + { + private final SettableSupplier> validateAndAuthorizeLatchSupplier; + private final SettableSupplier> planLatchSupplier; + private final SettableSupplier> executeLatchSupplier; + + private TestSqlLifecycle( + PlannerFactory plannerFactory, + ServiceEmitter emitter, + RequestLogger requestLogger, + QueryScheduler queryScheduler, + long startMs, + long startNs, + SettableSupplier> validateAndAuthorizeLatchSupplier, + SettableSupplier> planLatchSupplier, + SettableSupplier> executeLatchSupplier + ) + { + super(plannerFactory, emitter, requestLogger, queryScheduler, startMs, startNs); + this.validateAndAuthorizeLatchSupplier = validateAndAuthorizeLatchSupplier; + this.planLatchSupplier = planLatchSupplier; + this.executeLatchSupplier = executeLatchSupplier; + } + + @Override + public void validateAndAuthorize(HttpServletRequest req) + { + if (validateAndAuthorizeLatchSupplier.get() != null) { + if (validateAndAuthorizeLatchSupplier.get().rhs) { + super.validateAndAuthorize(req); + validateAndAuthorizeLatchSupplier.get().lhs.countDown(); + } else { + try { + if (!validateAndAuthorizeLatchSupplier.get().lhs.await(1, TimeUnit.SECONDS)) { + throw new RuntimeException("Latch timed out"); + } + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + super.validateAndAuthorize(req); + } + } else { + super.validateAndAuthorize(req); + } + } + + @Override + public void plan() throws RelConversionException + { + if (planLatchSupplier.get() != null) { + if (planLatchSupplier.get().rhs) { + super.plan(); + planLatchSupplier.get().lhs.countDown(); + } else { + try { + if (!planLatchSupplier.get().lhs.await(1, TimeUnit.SECONDS)) { + throw new RuntimeException("Latch timed out"); + } + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + super.plan(); + } + } else { + super.plan(); + } + } + + @Override + public Sequence execute() + { + if (executeLatchSupplier.get() != null) { + if (executeLatchSupplier.get().rhs) { + Sequence sequence = super.execute(); + executeLatchSupplier.get().lhs.countDown(); + return sequence; + } else { + try { + if (!executeLatchSupplier.get().lhs.await(1, TimeUnit.SECONDS)) { + throw new RuntimeException("Latch timed out"); + } + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return super.execute(); + } + } else { + return super.execute(); + } + } + } }